#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System.Collections.Generic;
using System.Linq;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
using HeuristicLab.Operators;
using HeuristicLab.Optimization;
using HeuristicLab.Parameters;
using HeuristicLab.Problems.DataAnalysis.Symbolic;
using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
using System;
using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;

namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Prunes symbolic regression trees of the population. The trees are sorted from best to worst
  /// quality and the trees inside [PopulationPercentileStart, PopulationPercentileEnd) are pruned.
  /// Each pruning iteration runs a tournament over candidate branches. Each candidate branch is
  /// temporarily replaced by a constant: the average output of that branch on a sample of rows.
  /// The candidate with the best quality/size coefficient replaces the original tree.
  /// Pruning runs only when ApplyPruning is true, starting at FirstPruningGeneration and then
  /// every PruningFrequency-th generation.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters["Quality"]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> RelativeNumberOfEvaluatedRowsParameters {
      get { return (IValueLookupParameter<PercentValue>)Parameters["RelativeNumberOfEvaluatedRows"]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    public IValueLookupParameter<BoolValue> ApplyPruningParameter {
      get { return (IValueLookupParameter<BoolValue>)Parameters["ApplyPruning"]; }
    }
    #endregion

    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<BoolValue>("ApplyPruning"));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
      Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeNumberOfEvaluatedRows", new PercentValue(1.0)));
    }

    /// <summary>
    /// Adds parameters that are missing from operators that were serialized by older versions.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      if (!Parameters.ContainsKey("ApplyPruning")) {
        Parameters.Add(new ValueLookupParameter<BoolValue>("ApplyPruning"));
      }
      if (!Parameters.ContainsKey("Quality")) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));
      }
      if (!Parameters.ContainsKey("RelativeNumberOfEvaluatedRows")) {
        Parameters.Add(new ValueLookupParameter<PercentValue>("RelativeNumberOfEvaluatedRows", new PercentValue(1.0)));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees of the selected population percentile when the pruning condition holds.
    /// The condition is: ApplyPruning is true, the current generation is at least
    /// FirstPruningGeneration, and the distance to FirstPruningGeneration is a multiple of PruningFrequency.
    /// </summary>
    public override IOperation Apply() {
      bool pruningCondition =
        ApplyPruningParameter.ActualValue.Value &&
        Generation.Value >= FirstPruningGeneration.Value &&
        (Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0;

      if (pruningCondition) {
        ItemArray<SymbolicExpressionTree> trees = SymbolicExpressionTree;
        ItemArray<DoubleValue> quality = QualityParameter.ActualValue;
        int n = trees.Length;
        double percentileStart = PopulationPercentileStart.Value;
        double percentileEnd = PopulationPercentileEnd.Value;
        bool maximization = Maximization.Value;

        // Sort the population from best to worst and take the trees in the given percentile.
        // Materialize the selection before pruning starts to modify trees and qualities.
        var selectedTrees = (from index in Enumerable.Range(0, n)
                             orderby maximization ? -quality[index].Value : quality[index].Value
                             select new { Tree = trees[index], Quality = quality[index] })
                            .Skip((int)(n * percentileStart))
                            .Take((int)(n * (percentileEnd - percentileStart)))
                            .ToList();

        // The pruning settings are the same for every tree, so look them up only once.
        IRandom random = Random;
        int iterations = Iterations.Value;
        int tournamentSize = TournamentSize.Value;
        DataAnalysisProblemData problemData = DataAnalysisProblemData;
        int samplesStart = SamplesStart.Value;
        int samplesEnd = SamplesEnd.Value;
        double relativeNumberOfEvaluatedRows = RelativeNumberOfEvaluatedRowsParameters.ActualValue.Value;
        ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;
        ISymbolicRegressionEvaluator evaluator = Evaluator;
        double lowerEstimationLimit = LowerEstimationLimit.Value;
        double upperEstimationLimit = UpperEstimationLimit.Value;
        double maxPruningRatio = MaxPruningRatio.Value;
        double qualityGainWeight = QualityGainWeight.Value;

        foreach (var pair in selectedTrees) {
          Prune(random, pair.Tree, pair.Quality, iterations, tournamentSize, problemData,
            samplesStart, samplesEnd, relativeNumberOfEvaluatedRows, interpreter, evaluator, maximization,
            lowerEstimationLimit, upperEstimationLimit, maxPruningRatio, qualityGainWeight);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Runs <paramref name="iterations"/> pruning tournaments on <paramref name="tree"/> and modifies the tree in place.
    /// Each tournament replaces one branch with a constant. A branch is pruned only if the resulting tree
    /// stays larger than (int)(originalSize * (1 - <paramref name="maxPruningRatio"/>)).
    /// </summary>
    /// <param name="random">Random number generator used to draw the tournament candidates.</param>
    /// <param name="tree">The tree to prune. Its Root is replaced when a pruning is applied.</param>
    /// <param name="quality">Quality of <paramref name="tree"/>. Set to the evaluated quality of the applied pruning.</param>
    /// <param name="relativeNumberOfEvaluatedRows">
    /// Fraction of the rows between <paramref name="samplesStart"/> and <paramref name="samplesEnd"/>.
    /// This fraction is sampled once and used for every evaluation of this tree.
    /// </param>
    /// <param name="qualityGainWeight">
    /// Exponent on the quality deterioration in the pruning coefficient. 1.0 weighs quality and size equally;
    /// 0 makes only the size reduction count.
    /// </param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, DoubleValue quality, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd, double relativeNumberOfEvaluatedRows,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit, double maxPruningRatio, double qualityGainWeight) {
      int originalSize = tree.Size;
      // min size of the resulting pruned tree
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Use the same subset of rows for all iterations and all pruning tournaments. The sampler's
      // result is only typed as IEnumerable<int>, so materialize it once. Otherwise every evaluation
      // would enumerate (and possibly re-sample) it again.
      int numberOfSampledRows = (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows);
      List<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, numberOfSampledRows).ToList();
      // Without rows no replacement value (Average of an empty sequence) or quality can be computed.
      if (rows.Count == 0) return;

      for (int iteration = 0; iteration < iterations; iteration++) {
        // Prune at most a branch this large, so the resulting tree is not smaller than
        // (1 - maxPruningRatio) of the original tree. tree.Size is re-read because earlier iterations shrink the tree.
        int maxPrunedBranchSize = tree.Size - minPrunedSize;
        if (maxPrunedBranchSize > 0) {
          PruneTournament(tree, quality, random, tournamentSize, maxPrunedBranchSize, maximization, qualityGainWeight,
            evaluator, interpreter, problemData.Dataset, problemData.TargetVariable.Value, rows,
            lowerEstimationLimit, upperEstimationLimit);
        }
      }
    }

    /// <summary>A candidate branch together with its parent and its position in the parent's subtrees.</summary>
    private class PruningPoint {
      public SymbolicExpressionTreeNode Parent { get; private set; }
      public SymbolicExpressionTreeNode Branch { get; private set; }
      public int SubTreeIndex { get; private set; }

      public PruningPoint(SymbolicExpressionTreeNode parent, SymbolicExpressionTreeNode branch, int index) {
        Parent = parent;
        Branch = branch;
        SubTreeIndex = index;
      }
    }

    /// <summary>
    /// Runs a single pruning tournament on <paramref name="tree"/>.
    /// Candidates are the non-constant branches below the start node that are at most
    /// <paramref name="maxPrunedBranchSize"/> nodes large. If there are more candidates than
    /// <paramref name="tournamentSize"/>, that many are drawn at random with replacement.
    /// Each candidate is evaluated with the branch replaced by its average output. The candidate with the
    /// lowest coefficient is written back to <paramref name="tree"/> and <paramref name="quality"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): originalQuality is taken from the incoming <paramref name="quality"/>, while prunedQuality
    /// is evaluated on <paramref name="rows"/> only. These can differ when rows is a sub-sample; confirm this is intended.
    /// </remarks>
    private static void PruneTournament(SymbolicExpressionTree tree, DoubleValue quality, IRandom random, int tournamentSize,
      int maxPrunedBranchSize, bool maximization, double qualityGainWeight, ISymbolicRegressionEvaluator evaluator,
      ISymbolicExpressionTreeInterpreter interpreter, Dataset ds, string targetVariable, IEnumerable<int> rows,
      double lowerEstimationLimit, double upperEstimationLimit) {
      // Candidates are pruned and restored on a clone, so the original tree stays intact while evaluating.
      SymbolicExpressionTree pruningEvaluationTree = (SymbolicExpressionTree)tree.Clone();
      List<PruningPoint> prunePoints = (from node in pruningEvaluationTree.Root.SubTrees[0].IterateNodesPostfix()
                                        from subTree in node.SubTrees
                                        let subTreeSize = subTree.GetSize()
                                        where subTreeSize <= maxPrunedBranchSize
                                        where !(subTree.Symbol is Constant)
                                        select new PruningPoint(node, subTree, node.SubTrees.IndexOf(subTree)))
                                       .ToList();
      double originalQuality = quality.Value;
      double originalSize = tree.Size;
      if (prunePoints.Count == 0) return;

      // Start at +infinity so that the best finite coefficient in the tournament is always applied.
      double bestCoeff = double.PositiveInfinity;
      List<PruningPoint> tournamentGroup;
      if (prunePoints.Count > tournamentSize) {
        tournamentGroup = new List<PruningPoint>();
        for (int i = 0; i < tournamentSize; i++) {
          tournamentGroup.Add(prunePoints.SelectRandom(random));
        }
      } else {
        tournamentGroup = prunePoints;
      }

      foreach (PruningPoint prunePoint in tournamentGroup) {
        double replacementValue = CalculateReplacementValue(prunePoint.Branch, interpreter, ds, rows);
        // temporarily replace the branch with a constant
        prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
        var constNode = CreateConstant(replacementValue);
        prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, constNode);
        // evaluate the pruned tree
        double prunedQuality = evaluator.Evaluate(interpreter, pruningEvaluationTree, lowerEstimationLimit, upperEstimationLimit, ds, targetVariable, rows);
        // the branch is replaced by a single constant node
        double prunedSize = originalSize - prunePoint.Branch.GetSize() + 1;
        double coeff = CalculatePruningCoefficient(maximization, qualityGainWeight, originalQuality, originalSize, prunedQuality, prunedSize);
        if (coeff < bestCoeff) {
          bestCoeff = coeff;
          // clone the currently pruned tree and update the original tree and quality
          SymbolicExpressionTree bestTree = (SymbolicExpressionTree)pruningEvaluationTree.Clone();
          tree.Root = bestTree.Root;
          quality.Value = prunedQuality;
        }
        // restore the tree that is used for pruning evaluation
        prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
        prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, prunePoint.Branch);
      }
    }

    /// <summary>
    /// Computes the pruning coefficient of a candidate; smaller is better.
    /// </summary>
    private static double CalculatePruningCoefficient(bool maximization, double qualityGainWeight, double originalQuality,
      double originalSize, double prunedQuality, double prunedSize) {
      // deterioration in quality (smaller is better):
      // MSE: newMse < origMse (improvement)   => prefer the larger improvement
      // MSE: newMse > origMse (deterioration) => prefer the smaller deterioration
      // MSE: minimize newMse / origMse
      // R²:  newR² > origR² (improvement)     => prefer the larger improvement
      // R²:  newR² < origR² (deterioration)   => prefer the smaller deterioration
      // R²:  minimize origR² / newR²
      double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
      // The pruned tree is always smaller than the original tree. For the same change in quality,
      // the pruning operation that removes the larger branch is preferred.
      // Coefficients are only compared with each other within one tournament. A plain factor
      // (qualityDeterioration * qualityGainWeight) would scale all of them equally and never change
      // the winner. Using the weight as an exponent makes it control how much quality counts relative
      // to size. For the default weight 1.0 the result equals the unweighted coefficient.
      return Math.Pow(qualityDeterioration, qualityGainWeight) / (originalSize / prunedSize);
    }

    /// <summary>
    /// Returns the average output of <paramref name="branch"/> on <paramref name="rows"/>.
    /// The branch is wrapped in a temporary program-root/start tree for the interpreter.
    /// </summary>
    private static double CalculateReplacementValue(SymbolicExpressionTreeNode branch, ISymbolicExpressionTreeInterpreter interpreter,
      Dataset ds, IEnumerable<int> rows) {
      SymbolicExpressionTreeNode start = (new StartSymbol()).CreateTreeNode();
      start.AddSubTree(branch);
      SymbolicExpressionTreeNode root = (new ProgramRootSymbol()).CreateTreeNode();
      root.AddSubTree(start);
      SymbolicExpressionTree tree = new SymbolicExpressionTree(root);
      IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(tree, ds, rows);
      return branchValues.Average();
    }

    /// <summary>Creates a constant tree node with the given value.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}