[15830] | 1 | #region License Information
|
---|
| 2 | /* HeuristicLab
|
---|
[17180] | 3 | * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
|
---|
[15830] | 4 | *
|
---|
| 5 | * This file is part of HeuristicLab.
|
---|
| 6 | *
|
---|
| 7 | * HeuristicLab is free software: you can redistribute it and/or modify
|
---|
| 8 | * it under the terms of the GNU General Public License as published by
|
---|
| 9 | * the Free Software Foundation, either version 3 of the License, or
|
---|
| 10 | * (at your option) any later version.
|
---|
| 11 | *
|
---|
| 12 | * HeuristicLab is distributed in the hope that it will be useful,
|
---|
| 13 | * but WITHOUT ANY WARRANTY; without even the implied warranty of
|
---|
| 14 | * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
---|
| 15 | * GNU General Public License for more details.
|
---|
| 16 | *
|
---|
| 17 | * You should have received a copy of the GNU General Public License
|
---|
| 18 | * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
|
---|
| 19 | */
|
---|
| 20 | #endregion
|
---|
| 21 |
|
---|
| 22 | using System;
|
---|
| 23 | using System.Collections.Generic;
|
---|
| 24 | using System.Linq;
|
---|
| 25 | using System.Threading;
|
---|
| 26 | using HeuristicLab.Common;
|
---|
| 27 | using HeuristicLab.Core;
|
---|
| 28 | using HeuristicLab.Data;
|
---|
| 29 | using HeuristicLab.Parameters;
|
---|
| 30 | using HeuristicLab.Problems.DataAnalysis;
|
---|
[16847] | 31 | using HEAL.Attic;
|
---|
[15830] | 32 |
|
---|
| 33 | namespace HeuristicLab.Algorithms.DataAnalysis {
|
---|
[16847] | 34 | [StorableType("B643D965-D13F-415D-8589-3F3527460347")]
|
---|
[15830] | 35 | public class ComplexityPruning : ParameterizedNamedItem, IPruning {
|
---|
/// <summary>Name of the scope variable under which the <see cref="PruningState"/> is stored.</summary>
public const string PruningStateVariableName = "PruningState";

/// <summary>Key of the pruning strength parameter in the parameter collection.</summary>
public const string PruningStrengthParameterName = "PruningStrength";

/// <summary>Key of the pruning decay parameter in the parameter collection.</summary>
public const string PruningDecayParameterName = "PruningDecay";

/// <summary>Key of the fast pruning parameter in the parameter collection.</summary>
public const string FastPruningParameterName = "FastPruning";
|
---|
[15830] | 41 |
|
---|
/// <summary>The parameter registered under <see cref="PruningStrengthParameterName"/>.</summary>
public IFixedValueParameter<DoubleValue> PruningStrengthParameter {
  get {
    var parameter = Parameters[PruningStrengthParameterName];
    return (IFixedValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>The parameter registered under <see cref="PruningDecayParameterName"/>.</summary>
public IFixedValueParameter<DoubleValue> PruningDecayParameter {
  get {
    var parameter = Parameters[PruningDecayParameterName];
    return (IFixedValueParameter<DoubleValue>)parameter;
  }
}

/// <summary>The parameter registered under <see cref="FastPruningParameterName"/>.</summary>
public IFixedValueParameter<BoolValue> FastPruningParameter {
  get {
    var parameter = Parameters[FastPruningParameterName];
    return (IFixedValueParameter<BoolValue>)parameter;
  }
}
|
---|
| 51 |
|
---|
/// <summary>Convenience accessor for the value held by <see cref="PruningStrengthParameter"/>.</summary>
public double PruningStrength {
  get {
    var holder = PruningStrengthParameter.Value;
    return holder.Value;
  }
  set {
    var holder = PruningStrengthParameter.Value;
    holder.Value = value;
  }
}

/// <summary>Convenience accessor for the value held by <see cref="PruningDecayParameter"/>.</summary>
public double PruningDecay {
  get {
    var holder = PruningDecayParameter.Value;
    return holder.Value;
  }
  set {
    var holder = PruningDecayParameter.Value;
    holder.Value = value;
  }
}

/// <summary>Convenience accessor for the value held by <see cref="FastPruningParameter"/>.</summary>
public bool FastPruning {
  get {
    var holder = FastPruningParameter.Value;
    return holder.Value;
  }
  set {
    var holder = FastPruningParameter.Value;
    holder.Value = value;
  }
}
|
---|
| 64 |
|
---|
| 65 | #region Constructors & Cloning
|
---|
/// <summary>Constructor used by the persistence framework; forwards the flag to the base class.</summary>
[StorableConstructor]
protected ComplexityPruning(StorableConstructorFlag _)
  : base(_) {
}

/// <summary>Cloning constructor; all copying is delegated to the base class.</summary>
protected ComplexityPruning(ComplexityPruning original, Cloner cloner)
  : base(original, cloner) {
}

/// <summary>
/// Creates a pruning instance and registers its three parameters with their defaults
/// (strength 2.0, decay 1.0, fast pruning enabled).
/// </summary>
public ComplexityPruning() {
  var strengthParameter = new FixedValueParameter<DoubleValue>(PruningStrengthParameterName, "The strength of the pruning. Higher values force the algorithm to create simpler models (default=2.0).", new DoubleValue(2.0));
  Parameters.Add(strengthParameter);

  var decayParameter = new FixedValueParameter<DoubleValue>(PruningDecayParameterName, "Pruning decay allows nodes higher up in the tree to be more stable (default=1.0).", new DoubleValue(1.0));
  Parameters.Add(decayParameter);

  var fastPruningParameter = new FixedValueParameter<BoolValue>(FastPruningParameterName, "Accelerate pruning by using linear models instead of leaf models (default=true).", new BoolValue(true));
  Parameters.Add(fastPruningParameter);
}

/// <summary>Creates a deep clone of this instance through the cloning constructor.</summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new ComplexityPruning(this, cloner);
  return copy;
}
|
---|
| 77 | #endregion
|
---|
| 78 |
|
---|
| 79 | #region IPruning
|
---|
/// <summary>
/// Returns the minimal leaf size for the given problem data, as reported by the leaf model
/// that pruning will use: a fresh <c>LinearLeaf</c> when <see cref="FastPruning"/> is set,
/// otherwise <paramref name="leafModel"/>.
/// </summary>
public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
  ILeafModel effectiveLeaf;
  if (FastPruning) {
    effectiveLeaf = new LinearLeaf();
  } else {
    effectiveLeaf = leafModel;
  }
  return effectiveLeaf.MinLeafSize(pd);
}
|
---|
/// <summary>
/// Adds a fresh <see cref="PruningState"/> to <paramref name="states"/> under the name
/// <see cref="PruningStateVariableName"/>.
/// </summary>
public void Initialize(IScope states) {
  var scopeVariables = states.Variables;
  var stateVariable = new Variable(PruningStateVariableName, new PruningState());
  scopeVariables.Add(stateVariable);
}
|
---|
| 86 |
|
---|
/// <summary>
/// Runs the pruning stages on <paramref name="treeModel"/>. The stage to continue with is read from the
/// <see cref="PruningState"/> stored in <paramref name="statescope"/>, so a run interrupted by
/// cancellation can be continued by calling this method again with the same scope.
/// </summary>
/// <param name="treeModel">The tree to prune.</param>
/// <param name="trainingRows">Row indices forwarded to the model-building stage as training rows.</param>
/// <param name="pruningRows">Row indices forwarded to the model-building stage as pruning rows.</param>
/// <param name="statescope">Scope holding the regression tree parameters and the pruning state.</param>
/// <param name="cancellationToken">Checked after every stage; throws when cancellation was requested.</param>
public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
  var treeParameters = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
  var pruningState = (PruningState)statescope.Variables[PruningStateVariableName].Value;

  // fast pruning substitutes a linear leaf for the configured leaf model
  ILeafModel pruningLeaf;
  if (FastPruning) {
    pruningLeaf = new LinearLeaf();
  } else {
    pruningLeaf = treeParameters.LeafModel;
  }

  // stages 0/1: build a model for every node
  if (pruningState.Code <= 1) {
    InstallModels(treeModel, pruningState, trainingRows, pruningRows, pruningLeaf, treeParameters, cancellationToken);
    cancellationToken.ThrowIfCancellationRequested();
  }

  // stage 2: assign a pruning threshold to every inner node
  if (pruningState.Code <= 2) {
    AssignPruningThresholds(treeModel, pruningState, PruningDecay);
    cancellationToken.ThrowIfCancellationRequested();
  }

  // stage 3: cap each threshold by the parent's threshold
  if (pruningState.Code <= 3) {
    UpdateThreshold(treeModel, pruningState);
    cancellationToken.ThrowIfCancellationRequested();
  }

  // stage 4: collapse the nodes whose threshold is below the pruning strength
  if (pruningState.Code <= 4) {
    Prune(treeModel, pruningState, PruningStrength);
    cancellationToken.ThrowIfCancellationRequested();
  }

  // mark pruning as finished
  pruningState.Code = 5;
}
|
---|
| 111 | #endregion
|
---|
| 112 |
|
---|
/// <summary>
/// Model-building stage of pruning: creates a pruning model for every node of <paramref name="tree"/>
/// and records its error and complexity in <paramref name="state"/>.
/// </summary>
/// <remarks>
/// When <c>state.Code</c> is 0 the work queues are filled bottom-up and the code is set to 1.
/// A node and its row sets are only removed from the queues after its model has been built,
/// so a cancellation leaves the unfinished node at the head of the queues for the next call.
/// </remarks>
/// <param name="tree">The tree whose nodes are processed.</param>
/// <param name="state">Pruning state holding the work queues and the recorded statistics.</param>
/// <param name="trainingRows">Row indices used to build the node models.</param>
/// <param name="pruningRows">Row indices used to evaluate the node models.</param>
/// <param name="leaf">Leaf model used to build the node models.</param>
/// <param name="regressionTreeParams">Tree parameters; supplies the dataset.</param>
/// <param name="cancellationToken">Checked before each node is processed.</param>
private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
  if (state.Code == 0) {
    // FillBottomUp declares its row arguments in the order (pruningRows, trainingRows);
    // passing them the other way round would exchange the two row sets for every node.
    state.FillBottomUp(tree, pruningRows, trainingRows, regressionTreeParams.Data);
    state.Code = 1;
  }
  while (state.nodeQueue.Count != 0) {
    cancellationToken.ThrowIfCancellationRequested();
    // peek first and dequeue only after success so that an interrupted node is retried on resume
    var n = state.nodeQueue.Peek();
    var training = state.trainingRowsQueue.Peek();
    var pruning = state.pruningRowsQueue.Peek();
    BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
    state.nodeQueue.Dequeue();
    state.trainingRowsQueue.Dequeue();
    state.pruningRowsQueue.Dequeue();
  }
}
|
---|
| 129 |
|
---|
/// <summary>
/// Threshold-assignment stage: computes a pruning strength for every inner node of
/// <paramref name="tree"/> from the statistics recorded in <paramref name="state"/>.
/// </summary>
/// <remarks>
/// When <c>state.Code</c> is 1 the node queue is refilled bottom-up and the code is set to 2;
/// otherwise the nodes still left in the queue are processed.
/// </remarks>
private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
  if (state.Code == 1) {
    state.FillBottomUp(tree);
    state.Code = 2;
  }

  // fetch the queue only now: FillBottomUp replaces the queue instance
  var pending = state.nodeQueue;
  while (pending.Count > 0) {
    var node = pending.Dequeue();
    if (!node.IsLeaf) {
      var pruningSize = state.pruningSizes[node];
      var modelComplexity = state.modelComplexities[node];
      var nodeComplexity = state.nodeComplexities[node];
      var modelError = state.modelErrors[node];
      var subtreeError = SubtreeError(node, state.pruningSizes, state.modelErrors);
      node.PruningStrength = PruningThreshold(pruningSize, modelComplexity, nodeComplexity, modelError, subtreeError, pruningDecay);
    }
  }
}
|
---|
| 141 |
|
---|
/// <summary>
/// Threshold-adjustment stage: walks the tree top-down and limits the pruning strength of every
/// inner node to that of its parent. Leaves, the root, and nodes whose parent has a NaN pruning
/// strength are left unchanged.
/// </summary>
/// <remarks>
/// When <c>state.Code</c> is 2 the node queue is refilled top-down and the code is set to 3;
/// otherwise the nodes still left in the queue are processed.
/// </remarks>
private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
  if (state.Code == 2) {
    state.FillTopDown(tree);
    state.Code = 3;
  }

  var pending = state.nodeQueue;
  while (pending.Count > 0) {
    var node = pending.Dequeue();
    if (node.IsLeaf) continue;

    var parent = node.Parent;
    if (parent == null) continue;
    if (double.IsNaN(parent.PruningStrength)) continue;

    node.PruningStrength = Math.Min(node.PruningStrength, parent.PruningStrength);
  }
}
|
---|
| 153 |
|
---|
/// <summary>
/// Final pruning stage: walks the tree top-down and turns every inner node into a leaf unless
/// <paramref name="pruningStrength"/> is less than or equal to the node's own pruning strength.
/// </summary>
/// <remarks>
/// When <c>state.Code</c> is 3 the node queue is refilled top-down and the code is set to 4;
/// otherwise the nodes still left in the queue are processed.
/// </remarks>
private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
  if (state.Code == 3) {
    state.FillTopDown(tree);
    state.Code = 4;
  }

  var pending = state.nodeQueue;
  while (pending.Count > 0) {
    var node = pending.Dequeue();
    if (node.IsLeaf) continue;

    // written as a negated <= on purpose: a NaN node strength makes the comparison false,
    // so such a node is converted as well
    var keepNode = pruningStrength <= node.PruningStrength;
    if (!keepNode) {
      node.ToLeaf();
    }
  }
}
|
---|
| 165 |
|
---|
/// <summary>
/// Builds a model for <paramref name="regressionNode"/> from <paramref name="trainingRows"/>, evaluates it on
/// <paramref name="pruningRows"/> and stores the resulting statistics in <paramref name="state"/>.
/// </summary>
/// <remarks>
/// Recorded per node: the number of pruning rows, the test RMSE of the model on the pruning rows,
/// the number of model parameters, and the node complexity. The node complexity of a leaf equals
/// its model complexity; for an inner node it is the sum of both children's node complexities plus one,
/// which requires the children to have been processed before their parent.
/// </remarks>
private static void BuildPruningModel(RegressionNodeModel regressionNode, ILeafModel leaf, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, PruningState state, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
  // dataset restricted to the pruning rows: allowed input variables followed by the target variable
  var variableNames = regressionTreeParams.AllowedInputVariables
    .Concat(new[] { regressionTreeParams.TargetVariable })
    .ToArray();
  var columns = variableNames.Select(name => regressionTreeParams.Data.GetDoubleValues(name, pruningRows).ToList());
  var reducedData = new Dataset(variableNames, columns);

  // problem data with an empty training partition and a test partition covering all reduced rows
  var pd = new RegressionProblemData(reducedData, regressionTreeParams.AllowedInputVariables, regressionTreeParams.TargetVariable);
  // same setter order as a chained assignment: right-most target first
  pd.TestPartition.Start = 0;
  pd.TrainingPartition.End = 0;
  pd.TrainingPartition.Start = 0;
  pd.TestPartition.End = reducedData.Rows;

  // fit the model on the training rows
  int numModelParams;
  var model = leaf.BuildModel(trainingRows, regressionTreeParams, cancellationToken, out numModelParams);

  // evaluate on the pruning rows and record error and complexities
  var solution = model.CreateRegressionSolution(pd);
  var rmsModel = solution.TestRootMeanSquaredError;
  state.pruningSizes.Add(regressionNode, pruningRows.Count);
  state.modelErrors.Add(regressionNode, rmsModel);
  state.modelComplexities.Add(regressionNode, numModelParams);

  if (regressionNode.IsLeaf) {
    state.nodeComplexities[regressionNode] = state.modelComplexities[regressionNode];
  } else {
    var leftComplexity = state.nodeComplexities[regressionNode.Left];
    var rightComplexity = state.nodeComplexities[regressionNode.Right];
    state.nodeComplexities.Add(regressionNode, leftComplexity + rightComplexity + 1);
  }
}
|
---|
| 185 |
|
---|
/// <summary>
/// Computes the pruning threshold of a node: the ratio of model error to subtree error, divided by
/// ((nodeParams + noIntances) / (2 * (modelParams + noIntances))) raised to the power <paramref name="w"/>.
/// The error ratio is taken as 1.0 when both errors are almost equal.
/// </summary>
private static double PruningThreshold(double noIntances, double modelParams, double nodeParams, double modelError, double nodeError, double w) {
  double errorRatio;
  if (modelError.IsAlmost(nodeError)) {
    errorRatio = 1.0;
  } else {
    errorRatio = modelError / nodeError;
  }

  double complexityRatio = (nodeParams + noIntances) / (2 * (modelParams + noIntances));
  return errorRatio / Math.Pow(complexityRatio, w);
}
|
---|
| 192 |
|
---|
/// <summary>
/// Computes the error of the subtree rooted at <paramref name="regressionNode"/>. For a leaf this is its
/// recorded model error. For an inner node the squared errors of both child subtrees are weighted with
/// the children's pruning sizes, summed, divided by the node's pruning size, and the square root is taken.
/// </summary>
private static double SubtreeError(RegressionNodeModel regressionNode, IDictionary<RegressionNodeModel, int> pruningSizes,
  IDictionary<RegressionNodeModel, double> modelErrors) {
  if (regressionNode.IsLeaf) {
    return modelErrors[regressionNode];
  }

  double leftError = SubtreeError(regressionNode.Left, pruningSizes, modelErrors);
  double rightError = SubtreeError(regressionNode.Right, pruningSizes, modelErrors);

  double leftSquaredSum = leftError * leftError * pruningSizes[regressionNode.Left];
  double rightSquaredSum = rightError * rightError * pruningSizes[regressionNode.Right];

  double meanSquared = (rightSquaredSum + leftSquaredSum) / pruningSizes[regressionNode];
  return Math.Sqrt(meanSquared);
}
|
---|
| 202 |
|
---|
[16847] | 203 | [StorableType("EAD60C7E-2C58-45C4-9697-6F735F518CFD")]
|
---|
[15830] | 204 | public class PruningState : Item {
|
---|
// Per-node number of model parameters, as reported when the node's model was built.
[Storable]
public IDictionary<RegressionNodeModel, int> modelComplexities;

// Per-node complexity: the model complexity for leaves, otherwise the sum of both children plus one.
[Storable]
public IDictionary<RegressionNodeModel, int> nodeComplexities;

// Per-node number of rows that were used to evaluate the node's model.
[Storable]
public IDictionary<RegressionNodeModel, int> pruningSizes;

// Per-node root mean squared error of the node's model.
[Storable]
public IDictionary<RegressionNodeModel, double> modelErrors;

// The queues are persisted as arrays through the private properties below
// and rebuilt from those arrays when the property is set.
[Storable]
private RegressionNodeModel[] storableNodeQueue {
  get { return nodeQueue.ToArray(); }
  set { nodeQueue = new Queue<RegressionNodeModel>(value); }
}
public Queue<RegressionNodeModel> nodeQueue;

[Storable]
private IReadOnlyList<int>[] storabletrainingRowsQueue {
  get { return trainingRowsQueue.ToArray(); }
  set { trainingRowsQueue = new Queue<IReadOnlyList<int>>(value); }
}
public Queue<IReadOnlyList<int>> trainingRowsQueue;

[Storable]
private IReadOnlyList<int>[] storablepruningRowsQueue {
  get { return pruningRowsQueue.ToArray(); }
  set { pruningRowsQueue = new Queue<IReadOnlyList<int>>(value); }
}
public Queue<IReadOnlyList<int>> pruningRowsQueue;

// Code denotes the current action, so that pruning can be paused and resumed:
//   0 ... nothing has been done
//   1 ... building models
//   2 ... assigning thresholds
//   3 ... adjusting thresholds
//   4 ... pruning
//   5 ... finished
[Storable]
public int Code = 0;
|
---|
| 233 |
|
---|
| 234 | #region HLConstructors & Cloning
|
---|
/// <summary>Constructor used by the persistence framework; forwards the flag to the base class.</summary>
[StorableConstructor]
protected PruningState(StorableConstructorFlag _)
  : base(_) {
}

/// <summary>
/// Cloning constructor. Dictionary keys and queued nodes are cloned through <paramref name="cloner"/>;
/// the queued row lists are copied into new arrays.
/// </summary>
protected PruningState(PruningState original, Cloner cloner)
  : base(original, cloner) {
  modelComplexities = new Dictionary<RegressionNodeModel, int>();
  foreach (var entry in original.modelComplexities) {
    modelComplexities.Add(cloner.Clone(entry.Key), entry.Value);
  }

  nodeComplexities = new Dictionary<RegressionNodeModel, int>();
  foreach (var entry in original.nodeComplexities) {
    nodeComplexities.Add(cloner.Clone(entry.Key), entry.Value);
  }

  pruningSizes = new Dictionary<RegressionNodeModel, int>();
  foreach (var entry in original.pruningSizes) {
    pruningSizes.Add(cloner.Clone(entry.Key), entry.Value);
  }

  modelErrors = new Dictionary<RegressionNodeModel, double>();
  foreach (var entry in original.modelErrors) {
    modelErrors.Add(cloner.Clone(entry.Key), entry.Value);
  }

  nodeQueue = new Queue<RegressionNodeModel>();
  foreach (var node in original.nodeQueue) {
    nodeQueue.Enqueue(cloner.Clone(node));
  }

  trainingRowsQueue = new Queue<IReadOnlyList<int>>();
  foreach (var rows in original.trainingRowsQueue) {
    trainingRowsQueue.Enqueue(rows.ToArray());
  }

  pruningRowsQueue = new Queue<IReadOnlyList<int>>();
  foreach (var rows in original.pruningRowsQueue) {
    pruningRowsQueue.Enqueue(rows.ToArray());
  }

  Code = original.Code;
}

/// <summary>Creates an empty state: empty dictionaries, empty queues, and the initial code.</summary>
public PruningState() {
  nodeQueue = new Queue<RegressionNodeModel>();
  trainingRowsQueue = new Queue<IReadOnlyList<int>>();
  pruningRowsQueue = new Queue<IReadOnlyList<int>>();

  modelComplexities = new Dictionary<RegressionNodeModel, int>();
  nodeComplexities = new Dictionary<RegressionNodeModel, int>();
  pruningSizes = new Dictionary<RegressionNodeModel, int>();
  modelErrors = new Dictionary<RegressionNodeModel, double>();
}

/// <summary>Creates a deep clone of this state through the cloning constructor.</summary>
public override IDeepCloneable Clone(Cloner cloner) {
  var copy = new PruningState(this, cloner);
  return copy;
}
|
---|
| 259 | #endregion
|
---|
| 260 |
|
---|
/// <summary>
/// Replaces the content of <c>nodeQueue</c> with all nodes of <paramref name="tree"/> in breadth-first
/// order: the root first, and for every inner node its left child before its right child.
/// </summary>
public void FillTopDown(RegressionNodeTreeModel tree) {
  nodeQueue.Clear();

  var pending = new Queue<RegressionNodeModel>();
  pending.Enqueue(tree.Root);

  while (pending.Count > 0) {
    var node = pending.Dequeue();
    // nodes are recorded in the order they leave the FIFO, which is the order they entered it
    nodeQueue.Enqueue(node);
    if (!node.IsLeaf) {
      pending.Enqueue(node.Left);
      pending.Enqueue(node.Right);
    }
  }
}
|
---|
| 277 |
|
---|
/// <summary>
/// Replaces the content of the three queues with all nodes of <paramref name="tree"/> in breadth-first
/// order, together with the pruning and training rows that reach each node. The rows of an inner node are
/// divided between its children with <c>RegressionTreeUtilities.SplitRows</c> using the node's split
/// attribute and split value.
/// </summary>
/// <param name="tree">The tree to traverse.</param>
/// <param name="pruningRows">Pruning rows assigned to the root. Note that this comes before the training rows.</param>
/// <param name="trainingRows">Training rows assigned to the root.</param>
/// <param name="data">Dataset handed to the row splitting.</param>
public void FillTopDown(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
  nodeQueue.Clear();
  trainingRowsQueue.Clear();
  pruningRowsQueue.Clear();

  var pendingNodes = new Queue<RegressionNodeModel>();
  var pendingTraining = new Queue<IReadOnlyList<int>>();
  var pendingPruning = new Queue<IReadOnlyList<int>>();

  pendingNodes.Enqueue(tree.Root);
  pendingTraining.Enqueue(trainingRows);
  pendingPruning.Enqueue(pruningRows);

  while (pendingNodes.Count > 0) {
    var node = pendingNodes.Dequeue();
    var nodePruning = pendingPruning.Dequeue();
    var nodeTraining = pendingTraining.Dequeue();

    // record the node with its rows; the three result queues stay aligned index by index
    nodeQueue.Enqueue(node);
    trainingRowsQueue.Enqueue(nodeTraining);
    pruningRowsQueue.Enqueue(nodePruning);

    if (node.IsLeaf) continue;

    IReadOnlyList<int> leftPruning, rightPruning;
    RegressionTreeUtilities.SplitRows(nodePruning, data, node.SplitAttribute, node.SplitValue, out leftPruning, out rightPruning);
    IReadOnlyList<int> leftTraining, rightTraining;
    RegressionTreeUtilities.SplitRows(nodeTraining, data, node.SplitAttribute, node.SplitValue, out leftTraining, out rightTraining);

    pendingNodes.Enqueue(node.Left);
    pendingTraining.Enqueue(leftTraining);
    pendingPruning.Enqueue(leftPruning);

    pendingNodes.Enqueue(node.Right);
    pendingTraining.Enqueue(rightTraining);
    pendingPruning.Enqueue(rightPruning);
  }
}
|
---|
| 322 |
|
---|
/// <summary>
/// Fills <c>nodeQueue</c> with all nodes of <paramref name="tree"/> in reversed top-down order,
/// so that every node is queued after its children. The queue instance is replaced.
/// </summary>
public void FillBottomUp(RegressionNodeTreeModel tree) {
  FillTopDown(tree);

  var nodes = nodeQueue.ToArray();
  Array.Reverse(nodes);
  nodeQueue = new Queue<RegressionNodeModel>(nodes);
}
|
---|
| 327 |
|
---|
/// <summary>
/// Fills the node queue and both row queues in reversed top-down order, so that every node
/// (with its rows) is queued after its children. All three queue instances are replaced.
/// </summary>
/// <param name="tree">The tree to traverse.</param>
/// <param name="pruningRows">Pruning rows assigned to the root. Note that this comes before the training rows.</param>
/// <param name="trainingRows">Training rows assigned to the root.</param>
/// <param name="data">Dataset handed to the row splitting.</param>
public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
  FillTopDown(tree, pruningRows, trainingRows, data);

  var nodes = nodeQueue.ToArray();
  Array.Reverse(nodes);
  nodeQueue = new Queue<RegressionNodeModel>(nodes);

  var pruning = pruningRowsQueue.ToArray();
  Array.Reverse(pruning);
  pruningRowsQueue = new Queue<IReadOnlyList<int>>(pruning);

  var training = trainingRowsQueue.ToArray();
  Array.Reverse(training);
  trainingRowsQueue = new Queue<IReadOnlyList<int>>(training);
}
|
---|
| 334 | }
|
---|
| 335 | }
|
---|
| 336 | } |
---|