#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using HeuristicLab.Common;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis;
using HEAL.Attic;

namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Complexity-based pruning for regression node trees. Builds a (leaf or linear) model for
  /// every node, derives a per-node pruning threshold from the ratio of the node's model error
  /// to its subtree's error (scaled by model/subtree complexity), and converts inner nodes to
  /// leaves whenever <see cref="PruningStrength"/> exceeds that threshold. All intermediate
  /// results live in a <see cref="PruningState"/> scope variable so the algorithm can be
  /// paused and resumed (state codes 0..5).
  /// </summary>
  [StorableType("B643D965-D13F-415D-8589-3F3527460347")]
  public class ComplexityPruning : ParameterizedNamedItem, IPruning {
    public const string PruningStateVariableName = "PruningState";
    public const string PruningStrengthParameterName = "PruningStrength";
    public const string PruningDecayParameterName = "PruningDecay";
    public const string FastPruningParameterName = "FastPruning";

    public IFixedValueParameter<DoubleValue> PruningStrengthParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningStrengthParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> PruningDecayParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PruningDecayParameterName]; }
    }
    public IFixedValueParameter<BoolValue> FastPruningParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[FastPruningParameterName]; }
    }

    public double PruningStrength {
      get { return PruningStrengthParameter.Value.Value; }
      set { PruningStrengthParameter.Value.Value = value; }
    }
    public double PruningDecay {
      get { return PruningDecayParameter.Value.Value; }
      set { PruningDecayParameter.Value.Value = value; }
    }
    public bool FastPruning {
      get { return FastPruningParameter.Value.Value; }
      set { FastPruningParameter.Value.Value = value; }
    }

    #region Constructors & Cloning
    [StorableConstructor]
    protected ComplexityPruning(StorableConstructorFlag _) : base(_) { }
    protected ComplexityPruning(ComplexityPruning original, Cloner cloner) : base(original, cloner) { }
    public ComplexityPruning() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningStrengthParameterName, "The strength of the pruning. Higher values force the algorithm to create simpler models (default=2.0).", new DoubleValue(2.0)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PruningDecayParameterName, "Pruning decay allows nodes higher up in the tree to be more stable (default=1.0).", new DoubleValue(1.0)));
      Parameters.Add(new FixedValueParameter<BoolValue>(FastPruningParameterName, "Accelerate pruning by using linear models instead of leaf models (default=true).", new BoolValue(true)));
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new ComplexityPruning(this, cloner);
    }
    #endregion

    #region IPruning
    // The minimum leaf size is dictated by the model type used during pruning:
    // with fast pruning that is always a linear leaf, otherwise the tree's own leaf model.
    public int MinLeafSize(IRegressionProblemData pd, ILeafModel leafModel) {
      return (FastPruning ? new LinearLeaf() : leafModel).MinLeafSize(pd);
    }

    public void Initialize(IScope states) {
      states.Variables.Add(new Variable(PruningStateVariableName, new PruningState()));
    }

    /// <summary>
    /// Runs (or resumes) the four pruning phases. Each phase checks the state code so a
    /// cancelled/paused run can pick up where it left off.
    /// </summary>
    public void Prune(RegressionNodeTreeModel treeModel, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, IScope statescope, CancellationToken cancellationToken) {
      var regressionTreeParams = (RegressionTreeParameters)statescope.Variables[DecisionTreeRegression.RegressionTreeParameterVariableName].Value;
      var state = (PruningState)statescope.Variables[PruningStateVariableName].Value;

      var leaf = FastPruning ? new LinearLeaf() : regressionTreeParams.LeafModel;
      if (state.Code <= 1) {
        InstallModels(treeModel, state, trainingRows, pruningRows, leaf, regressionTreeParams, cancellationToken);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 2) {
        AssignPruningThresholds(treeModel, state, PruningDecay);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 3) {
        UpdateThreshold(treeModel, state);
        cancellationToken.ThrowIfCancellationRequested();
      }
      if (state.Code <= 4) {
        Prune(treeModel, state, PruningStrength);
        cancellationToken.ThrowIfCancellationRequested();
      }
      state.Code = 5;
    }
    #endregion

    // Phase 1: build a pruning model for every node (bottom-up) and record its error,
    // complexity and pruning-set size in the state dictionaries.
    private static void InstallModels(RegressionNodeTreeModel tree, PruningState state, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, ILeafModel leaf, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      if (state.Code == 0) {
        // NOTE(review): FillBottomUp declares its row parameters as (pruningRows, trainingRows)
        // but is called here with (trainingRows, pruningRows) — verify the intended ordering.
        state.FillBottomUp(tree, trainingRows, pruningRows, regressionTreeParams.Data);
        state.Code = 1;
      }
      while (state.nodeQueue.Count != 0) {
        cancellationToken.ThrowIfCancellationRequested();
        var n = state.nodeQueue.Peek();
        var training = state.trainingRowsQueue.Peek();
        var pruning = state.pruningRowsQueue.Peek();
        BuildPruningModel(n, leaf, training, pruning, state, regressionTreeParams, cancellationToken);
        // Dequeue only after the model is recorded so a cancellation mid-node retries it on resume.
        state.nodeQueue.Dequeue();
        state.trainingRowsQueue.Dequeue();
        state.pruningRowsQueue.Dequeue();
      }
    }

    // Phase 2: compute each inner node's pruning threshold from its model error vs. its
    // subtree's aggregated error (bottom-up, so children are already populated).
    private static void AssignPruningThresholds(RegressionNodeTreeModel tree, PruningState state, double pruningDecay) {
      if (state.Code == 1) {
        state.FillBottomUp(tree);
        state.Code = 2;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf) continue;
        n.PruningStrength = PruningThreshold(state.pruningSizes[n], state.modelComplexities[n], state.nodeComplexities[n], state.modelErrors[n], SubtreeError(n, state.pruningSizes, state.modelErrors), pruningDecay);
      }
    }

    // Phase 3: propagate thresholds top-down so no node keeps a higher threshold than its parent.
    private static void UpdateThreshold(RegressionNodeTreeModel tree, PruningState state) {
      if (state.Code == 2) {
        state.FillTopDown(tree);
        state.Code = 3;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf || n.Parent == null || double.IsNaN(n.Parent.PruningStrength)) continue;
        n.PruningStrength = Math.Min(n.PruningStrength, n.Parent.PruningStrength);
      }
    }

    // Phase 4: collapse every inner node whose threshold is below the configured pruning strength.
    private static void Prune(RegressionNodeTreeModel tree, PruningState state, double pruningStrength) {
      if (state.Code == 3) {
        state.FillTopDown(tree);
        state.Code = 4;
      }
      while (state.nodeQueue.Count != 0) {
        var n = state.nodeQueue.Dequeue();
        if (n.IsLeaf || pruningStrength <= n.PruningStrength) continue;
        n.ToLeaf();
      }
    }

    // Builds the per-node model, evaluates it on the pruning rows, and records error and
    // complexity figures for this node (and, for inner nodes, the aggregated subtree complexity).
    private static void BuildPruningModel(RegressionNodeModel regressionNode, ILeafModel leaf, IReadOnlyList<int> trainingRows, IReadOnlyList<int> pruningRows, PruningState state, RegressionTreeParameters regressionTreeParams, CancellationToken cancellationToken) {
      //create regressionProblemdata from pruning data
      var vars = regressionTreeParams.AllowedInputVariables.Concat(new[] { regressionTreeParams.TargetVariable }).ToArray();
      var reducedData = new Dataset(vars, vars.Select(x => regressionTreeParams.Data.GetDoubleValues(x, pruningRows).ToList()));
      var pd = new RegressionProblemData(reducedData, regressionTreeParams.AllowedInputVariables, regressionTreeParams.TargetVariable);
      // Empty training partition; the test partition spans all pruning rows so
      // TestRootMeanSquaredError below measures the error on the pruning set.
      pd.TrainingPartition.Start = pd.TrainingPartition.End = pd.TestPartition.Start = 0;
      pd.TestPartition.End = reducedData.Rows;

      //build pruning model
      int numModelParams;
      var model = leaf.BuildModel(trainingRows, regressionTreeParams, cancellationToken, out numModelParams);

      //record error and complexities
      var rmsModel = model.CreateRegressionSolution(pd).TestRootMeanSquaredError;
      state.pruningSizes.Add(regressionNode, pruningRows.Count);
      state.modelErrors.Add(regressionNode, rmsModel);
      state.modelComplexities.Add(regressionNode, numModelParams);
      if (regressionNode.IsLeaf) {
        state.nodeComplexities[regressionNode] = state.modelComplexities[regressionNode];
      } else {
        // Inner node: children were processed first (bottom-up traversal), so their
        // aggregated complexities are already available. +1 accounts for the split itself.
        state.nodeComplexities.Add(regressionNode, state.nodeComplexities[regressionNode.Left] + state.nodeComplexities[regressionNode.Right] + 1);
      }
    }

    // Threshold = (modelError / subtreeError) / ((nodeParams + n) / (2 * (modelParams + n)))^w.
    // A ratio of exactly 1.0 is forced when the errors are (almost) equal to avoid
    // floating-point noise deciding the pruning direction.
    private static double PruningThreshold(double noInstances, double modelParams, double nodeParams, double modelError, double nodeError, double w) {
      var res = modelError / nodeError;
      if (modelError.IsAlmost(nodeError)) res = 1.0;
      res /= Math.Pow((nodeParams + noInstances) / (2 * (modelParams + noInstances)), w);
      return res;
    }

    // Size-weighted root-mean aggregation of the leaf-model errors below the given node.
    private static double SubtreeError(RegressionNodeModel regressionNode, IDictionary<RegressionNodeModel, int> pruningSizes, IDictionary<RegressionNodeModel, double> modelErrors) {
      if (regressionNode.IsLeaf) return modelErrors[regressionNode];
      var errorL = SubtreeError(regressionNode.Left, pruningSizes, modelErrors);
      var errorR = SubtreeError(regressionNode.Right, pruningSizes, modelErrors);
      errorL = errorL * errorL * pruningSizes[regressionNode.Left];
      errorR = errorR * errorR * pruningSizes[regressionNode.Right];
      return Math.Sqrt((errorR + errorL) / pruningSizes[regressionNode]);
    }

    /// <summary>
    /// Storable, pausable state of a pruning run: per-node statistics plus the traversal
    /// queues, so a run can resume mid-phase after deserialization.
    /// </summary>
    [StorableType("EAD60C7E-2C58-45C4-9697-6F735F518CFD")]
    public class PruningState : Item {
      [Storable]
      public IDictionary<RegressionNodeModel, int> modelComplexities;
      [Storable]
      public IDictionary<RegressionNodeModel, int> nodeComplexities;
      [Storable]
      public IDictionary<RegressionNodeModel, int> pruningSizes;
      [Storable]
      public IDictionary<RegressionNodeModel, double> modelErrors;

      // Queues are persisted through array-backed proxy properties because Queue<T>
      // itself is not storable.
      [Storable]
      private RegressionNodeModel[] storableNodeQueue {
        get { return nodeQueue.ToArray(); }
        set { nodeQueue = new Queue<RegressionNodeModel>(value); }
      }
      public Queue<RegressionNodeModel> nodeQueue;

      [Storable]
      private IReadOnlyList<int>[] storabletrainingRowsQueue {
        get { return trainingRowsQueue.ToArray(); }
        set { trainingRowsQueue = new Queue<IReadOnlyList<int>>(value); }
      }
      public Queue<IReadOnlyList<int>> trainingRowsQueue;

      [Storable]
      private IReadOnlyList<int>[] storablepruningRowsQueue {
        get { return pruningRowsQueue.ToArray(); }
        set { pruningRowsQueue = new Queue<IReadOnlyList<int>>(value); }
      }
      public Queue<IReadOnlyList<int>> pruningRowsQueue;

      //State.Code values denote the current action (for pausing)
      //0...nothing has been done;
      //1...building Models;
      //2...assigning threshold
      //3...adjusting threshold
      //4...pruning
      //5...finished
      [Storable]
      public int Code = 0;

      #region HLConstructors & Cloning
      [StorableConstructor]
      protected PruningState(StorableConstructorFlag _) : base(_) { }
      protected PruningState(PruningState original, Cloner cloner) : base(original, cloner) {
        // Node keys are cloned; row lists are copied by value (plain int arrays).
        modelComplexities = original.modelComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeComplexities = original.nodeComplexities.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        pruningSizes = original.pruningSizes.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        modelErrors = original.modelErrors.ToDictionary(x => cloner.Clone(x.Key), x => x.Value);
        nodeQueue = new Queue<RegressionNodeModel>(original.nodeQueue.Select(cloner.Clone));
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(original.trainingRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(original.pruningRowsQueue.Select(x => (IReadOnlyList<int>)x.ToArray()));
        Code = original.Code;
      }
      public PruningState() {
        modelComplexities = new Dictionary<RegressionNodeModel, int>();
        nodeComplexities = new Dictionary<RegressionNodeModel, int>();
        pruningSizes = new Dictionary<RegressionNodeModel, int>();
        modelErrors = new Dictionary<RegressionNodeModel, double>();
        nodeQueue = new Queue<RegressionNodeModel>();
        trainingRowsQueue = new Queue<IReadOnlyList<int>>();
        pruningRowsQueue = new Queue<IReadOnlyList<int>>();
      }
      public override IDeepCloneable Clone(Cloner cloner) {
        return new PruningState(this, cloner);
      }
      #endregion

      // Fills nodeQueue in breadth-first (top-down) order starting at the root.
      public void FillTopDown(RegressionNodeTreeModel tree) {
        var helperQueue = new Queue<RegressionNodeModel>();
        nodeQueue.Clear();
        helperQueue.Enqueue(tree.Root);
        nodeQueue.Enqueue(tree.Root);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          if (n.IsLeaf) continue;
          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
        }
      }

      // Breadth-first fill that additionally splits the given row sets along each node's
      // split attribute/value, so every queued node is paired with its own row subsets.
      public void FillTopDown(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        var helperQueue = new Queue<RegressionNodeModel>();
        var trainingHelperQueue = new Queue<IReadOnlyList<int>>();
        var pruningHelperQueue = new Queue<IReadOnlyList<int>>();
        nodeQueue.Clear();
        trainingRowsQueue.Clear();
        pruningRowsQueue.Clear();

        helperQueue.Enqueue(tree.Root);
        trainingHelperQueue.Enqueue(trainingRows);
        pruningHelperQueue.Enqueue(pruningRows);
        nodeQueue.Enqueue(tree.Root);
        trainingRowsQueue.Enqueue(trainingRows);
        pruningRowsQueue.Enqueue(pruningRows);

        while (helperQueue.Count != 0) {
          var n = helperQueue.Dequeue();
          var p = pruningHelperQueue.Dequeue();
          var t = trainingHelperQueue.Dequeue();
          if (n.IsLeaf) continue;
          IReadOnlyList<int> leftPruning, rightPruning;
          RegressionTreeUtilities.SplitRows(p, data, n.SplitAttribute, n.SplitValue, out leftPruning, out rightPruning);
          IReadOnlyList<int> leftTraining, rightTraining;
          RegressionTreeUtilities.SplitRows(t, data, n.SplitAttribute, n.SplitValue, out leftTraining, out rightTraining);

          helperQueue.Enqueue(n.Left);
          helperQueue.Enqueue(n.Right);
          trainingHelperQueue.Enqueue(leftTraining);
          trainingHelperQueue.Enqueue(rightTraining);
          pruningHelperQueue.Enqueue(leftPruning);
          pruningHelperQueue.Enqueue(rightPruning);

          nodeQueue.Enqueue(n.Left);
          nodeQueue.Enqueue(n.Right);
          trainingRowsQueue.Enqueue(leftTraining);
          trainingRowsQueue.Enqueue(rightTraining);
          pruningRowsQueue.Enqueue(leftPruning);
          pruningRowsQueue.Enqueue(rightPruning);
        }
      }

      // Bottom-up order = reversed breadth-first order (children before parents).
      public void FillBottomUp(RegressionNodeTreeModel tree) {
        FillTopDown(tree);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
      }

      public void FillBottomUp(RegressionNodeTreeModel tree, IReadOnlyList<int> pruningRows, IReadOnlyList<int> trainingRows, IDataset data) {
        FillTopDown(tree, pruningRows, trainingRows, data);
        nodeQueue = new Queue<RegressionNodeModel>(nodeQueue.Reverse());
        pruningRowsQueue = new Queue<IReadOnlyList<int>>(pruningRowsQueue.Reverse());
        trainingRowsQueue = new Queue<IReadOnlyList<int>>(trainingRowsQueue.Reverse());
      }
    }
  }
}