source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4350

Last change on this file since 4350 was 4350, checked in by gkronber, 12 years ago

added minimal size parameter for pruning operator. #1142

File size: 23.1 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using System;
34using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
35
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that shrinks symbolic regression trees by replacing branches with constants.
  /// For every tree inside a configurable quality percentile of the population it runs, per
  /// iteration, a tournament among candidate branches: each candidate is temporarily replaced by
  /// the average output of the branch, the resulting tree is evaluated, and the replacement with the
  /// smallest pruning coefficient (quality change relative to size reduction) is written back into
  /// the tree together with its quality.
  /// </summary>
  /// <remarks>
  /// NOTE(review): the type declares a storable constructor and an after-deserialization hook, but no
  /// storable-class attribute is visible on it in this file - confirm that persistence picks it up.
  /// </remarks>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string QualityParameterName = "Quality";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string RelativeNumberOfEvaluatedRowsParameterName = "RelativeNumberOfEvaluatedRows";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string ApplyPruningParameterName = "ApplyPruning";
    private const string MinimalTreeSizeParameterName = "MinimalTreeSize";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    // Typed accessors for the entries of the Parameters collection that are registered in the constructor.
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    // The plural property name is kept as is because it is part of the public interface.
    public IValueLookupParameter<PercentValue> RelativeNumberOfEvaluatedRowsParameters {
      get { return (IValueLookupParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedRowsParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    public IValueLookupParameter<BoolValue> ApplyPruningParameter {
      get { return (IValueLookupParameter<BoolValue>)Parameters[ApplyPruningParameterName]; }
    }
    public IValueLookupParameter<IntValue> MinimalTreeSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[MinimalTreeSizeParameterName]; }
    }
    #endregion
    #region properties
    // Shortcuts that return the ActualValue of the corresponding parameter.
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion
    [StorableConstructor]
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Creates the operator and registers all of its parameters together with their default values.
    /// </summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<BoolValue>(ApplyPruningParameterName));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(RelativeNumberOfEvaluatedRowsParameterName, new PercentValue(1.0)));
      Parameters.Add(new ValueLookupParameter<IntValue>(MinimalTreeSizeParameterName, new IntValue(15)));
    }

    /// <summary>
    /// Adds parameters that are missing from the restored parameter collection, so that instances
    /// saved before those parameters were introduced keep working.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      if (!Parameters.ContainsKey(ApplyPruningParameterName)) {
        Parameters.Add(new ValueLookupParameter<BoolValue>(ApplyPruningParameterName));
      }
      if (!Parameters.ContainsKey(QualityParameterName)) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      }
      if (!Parameters.ContainsKey(RelativeNumberOfEvaluatedRowsParameterName)) {
        Parameters.Add(new ValueLookupParameter<PercentValue>(RelativeNumberOfEvaluatedRowsParameterName, new PercentValue(1.0)));
      }
      if (!Parameters.ContainsKey(MinimalTreeSizeParameterName)) {
        Parameters.Add(new ValueLookupParameter<IntValue>(MinimalTreeSizeParameterName, new IntValue(15)));
      }

      #endregion
    }

    /// <summary>
    /// Prunes the trees inside the configured quality percentile when pruning is enabled, the current
    /// generation has reached the first pruning generation, and the number of generations since then
    /// is a multiple of the pruning frequency. Always continues with the successor operation.
    /// </summary>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      // the checks short-circuit in this order, so later lookups are skipped once a condition fails
      // NOTE(review): a pruning frequency of 0 makes the modulo below throw - confirm it is validated elsewhere
      bool pruningCondition =
        ApplyPruningParameter.ActualValue.Value &&
        Generation.Value >= FirstPruningGeneration.Value &&
        (Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0;
      if (pruningCondition) {
        int n = SymbolicExpressionTree.Length;
        double percentileStart = PopulationPercentileStart.Value;
        double percentileEnd = PopulationPercentileEnd.Value;
        ItemArray<SymbolicExpressionTree> trees = SymbolicExpressionTree;
        ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;
        bool maximization = Maximization.Value;

        // Sort best-first and cut out the requested percentile. The selection is materialized before
        // any tree is pruned because pruning overwrites the quality values the sort key is built from.
        var selectedTrees = Enumerable.Range(0, n)
          .OrderBy(index => maximization ? -qualities[index].Value : qualities[index].Value)
          .Select(index => new { Tree = trees[index], Quality = qualities[index] })
          .Skip((int)(n * percentileStart))
          .Take((int)(n * (percentileEnd - percentileStart)))
          .ToList();

        if (selectedTrees.Count > 0) {
          // resolve the settings once instead of once per tree (and only when there is work to do)
          IRandom random = Random;
          int iterations = Iterations.Value;
          int tournamentSize = TournamentSize.Value;
          DataAnalysisProblemData problemData = DataAnalysisProblemData;
          int samplesStart = SamplesStart.Value;
          int samplesEnd = SamplesEnd.Value;
          double relativeNumberOfEvaluatedRows = RelativeNumberOfEvaluatedRowsParameters.ActualValue.Value;
          ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;
          ISymbolicRegressionEvaluator evaluator = Evaluator;
          double lowerEstimationLimit = LowerEstimationLimit.Value;
          double upperEstimationLimit = UpperEstimationLimit.Value;
          int minTreeSize = MinimalTreeSizeParameter.ActualValue.Value;
          double maxPruningRatio = MaxPruningRatio.Value;
          double qualityGainWeight = QualityGainWeight.Value;

          foreach (var pair in selectedTrees) {
            Prune(random, pair.Tree, pair.Quality, iterations, tournamentSize,
              problemData, samplesStart, samplesEnd, relativeNumberOfEvaluatedRows,
              interpreter, evaluator, maximization,
              lowerEstimationLimit, upperEstimationLimit,
              minTreeSize, maxPruningRatio, qualityGainWeight);
          }
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Runs up to <paramref name="iterations"/> pruning tournaments on <paramref name="tree"/>. The tree
    /// and <paramref name="quality"/> are modified in place.
    /// </summary>
    /// <param name="random">Random number generator used to draw the tournament candidates.</param>
    /// <param name="tree">The tree to prune; its root is replaced when a pruning is applied.</param>
    /// <param name="quality">Quality of the tree; overwritten with the quality measured for the pruned tree.</param>
    /// <param name="iterations">Maximal number of tournaments (each replaces at most one branch).</param>
    /// <param name="tournamentSize">Number of candidate branches compared per tournament.</param>
    /// <param name="problemData">Supplies the dataset and the name of the target variable.</param>
    /// <param name="samplesStart">Start of the row range from which evaluation rows are sampled.</param>
    /// <param name="samplesEnd">End of the row range from which evaluation rows are sampled.</param>
    /// <param name="relativeNumberOfEvaluatedRows">Fraction of the row range that is sampled for evaluation.</param>
    /// <param name="interpreter">Interpreter used to evaluate trees and branches.</param>
    /// <param name="evaluator">Evaluator used to measure the quality of a pruned tree.</param>
    /// <param name="maximization">True if larger quality values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="minTreeSize">Trees are never pruned below this size.</param>
    /// <param name="maxPruningRatio">Maximal fraction of the original tree size that may be removed in total.</param>
    /// <param name="qualityGainWeight">Weight of the quality term in the pruning coefficient.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, DoubleValue quality, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd, double relativeNumberOfEvaluatedRows,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      int minTreeSize, double maxPruningRatio, double qualityGainWeight) {
      int originalSize = tree.Size;
      // min size of the resulting pruned tree: bounded by the pruning ratio and by the absolute minimum
      int minPrunedSize = Math.Max((int)(originalSize * (1 - maxPruningRatio)), minTreeSize);

      // Use the same subset of rows for all iterations and for all pruning tournaments. The sample is
      // materialized once so that this holds regardless of how the sampler behaves when it is
      // enumerated repeatedly, and so that it is not regenerated for every evaluation.
      int sampleCount = (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows);
      List<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, sampleCount).ToList();

      for (int iteration = 0; iteration < iterations; iteration++) {
        // maximally prune a branch such that the resulting tree size is not smaller than the minimal pruned size
        int maxPrunedBranchSize = tree.Size - minPrunedSize;
        // no tournament runs without budget, so the size cannot change and later iterations would be no-ops too
        if (maxPrunedBranchSize <= 0) break;
        PruneTournament(tree, quality, random, tournamentSize, maxPrunedBranchSize, maximization, qualityGainWeight,
          evaluator, interpreter, problemData.Dataset, problemData.TargetVariable.Value, rows,
          lowerEstimationLimit, upperEstimationLimit);
      }
    }

    /// <summary>
    /// Describes one candidate for pruning: a branch together with its parent node and its position
    /// among the parent's sub-trees.
    /// </summary>
    private class PruningPoint {
      public SymbolicExpressionTreeNode Parent { get; private set; }
      public SymbolicExpressionTreeNode Branch { get; private set; }
      public int SubTreeIndex { get; private set; }
      public PruningPoint(SymbolicExpressionTreeNode parent, SymbolicExpressionTreeNode branch, int index) {
        Parent = parent;
        Branch = branch;
        SubTreeIndex = index;
      }
    }

    /// <summary>
    /// Evaluates a tournament of candidate branches on a private copy of <paramref name="tree"/> and
    /// applies the replacement with the smallest pruning coefficient. If a replacement is applied,
    /// the root of <paramref name="tree"/> is exchanged and <paramref name="quality"/> is set to the
    /// quality that was measured for the pruned tree on <paramref name="rows"/>.
    /// </summary>
    private static void PruneTournament(SymbolicExpressionTree tree, DoubleValue quality, IRandom random, int tournamentSize,
      int maxPrunedBranchSize, bool maximization, double qualityGainWeight, ISymbolicRegressionEvaluator evaluator, ISymbolicExpressionTreeInterpreter interpreter,
      Dataset ds, string targetVariable, IEnumerable<int> rows, double lowerEstimationLimit, double upperEstimationLimit) {
      // all trial replacements are made on a clone so that the original stays intact until a winner is known
      SymbolicExpressionTree pruningEvaluationTree = (SymbolicExpressionTree)tree.Clone();
      List<PruningPoint> prunePoints = CollectPruningPoints(pruningEvaluationTree, maxPrunedBranchSize);
      if (prunePoints.Count == 0) return;

      double originalQuality = quality.Value;
      double originalSize = tree.Size;
      List<PruningPoint> tournamentGroup = SelectTournamentGroup(prunePoints, random, tournamentSize);

      double bestCoeff = double.PositiveInfinity;
      double bestQuality = originalQuality;
      PruningPoint bestPoint = null;
      SymbolicExpressionTreeNode bestReplacement = null;
      foreach (PruningPoint candidate in tournamentGroup) {
        double replacementValue = CalculateReplacementValue(candidate.Branch, interpreter, ds, rows);
        SymbolicExpressionTreeNode constNode = CreateConstant(replacementValue);

        // temporarily replace the branch with a constant
        candidate.Parent.RemoveSubTree(candidate.SubTreeIndex);
        candidate.Parent.InsertSubTree(candidate.SubTreeIndex, constNode);

        // evaluate the pruned tree
        double prunedQuality = evaluator.Evaluate(interpreter, pruningEvaluationTree,
          lowerEstimationLimit, upperEstimationLimit, ds, targetVariable, rows);

        // the branch is replaced by a single constant node
        double prunedSize = originalSize - candidate.Branch.GetSize() + 1;

        double coeff = CalculatePruningCoefficient(maximization, qualityGainWeight, originalQuality, originalSize, prunedQuality, prunedSize);
        // a NaN or infinite coefficient never compares smaller, so such candidates are never chosen
        if (coeff < bestCoeff) {
          bestCoeff = coeff;
          bestQuality = prunedQuality;
          bestPoint = candidate;
          bestReplacement = constNode;
        }

        // restore the evaluation tree so that the stored indices stay valid for the next candidate
        candidate.Parent.RemoveSubTree(candidate.SubTreeIndex);
        candidate.Parent.InsertSubTree(candidate.SubTreeIndex, candidate.Branch);
      }

      if (bestPoint != null) {
        // Re-apply the winning replacement once and hand the working copy over to the original tree.
        // The copy is private to this method, so no further clone is required.
        bestPoint.Parent.RemoveSubTree(bestPoint.SubTreeIndex);
        bestPoint.Parent.InsertSubTree(bestPoint.SubTreeIndex, bestReplacement);
        tree.Root = pruningEvaluationTree.Root;
        quality.Value = bestQuality;
      }
    }

    /// <summary>
    /// Lists all branches below the first sub-tree of the root that may be replaced: branches of at
    /// most <paramref name="maxPrunedBranchSize"/> nodes whose own symbol is not a constant.
    /// </summary>
    private static List<PruningPoint> CollectPruningPoints(SymbolicExpressionTree tree, int maxPrunedBranchSize) {
      var points = new List<PruningPoint>();
      foreach (SymbolicExpressionTreeNode node in tree.Root.SubTrees[0].IterateNodesPostfix()) {
        // track the position while iterating instead of searching for it afterwards
        int index = 0;
        foreach (SymbolicExpressionTreeNode subTree in node.SubTrees) {
          if (!(subTree.Symbol is Constant) && subTree.GetSize() <= maxPrunedBranchSize) {
            points.Add(new PruningPoint(node, subTree, index));
          }
          index++;
        }
      }
      return points;
    }

    /// <summary>
    /// Returns all pruning points if there are at most <paramref name="tournamentSize"/> of them;
    /// otherwise draws <paramref name="tournamentSize"/> points at random. Each draw is made from the
    /// full list, so the same point can be drawn more than once.
    /// </summary>
    private static List<PruningPoint> SelectTournamentGroup(List<PruningPoint> prunePoints, IRandom random, int tournamentSize) {
      if (prunePoints.Count <= tournamentSize) return prunePoints;
      var group = new List<PruningPoint>(tournamentSize);
      for (int i = 0; i < tournamentSize; i++) {
        group.Add(prunePoints.SelectRandom(random));
      }
      return group;
    }

    /// <summary>
    /// Rates a pruning operation; smaller values are better.
    /// </summary>
    /// <returns>
    /// The weighted quality deterioration divided by the size reduction factor (original size / pruned size).
    /// </returns>
    private static double CalculatePruningCoefficient(bool maximization, double qualityGainWeight, double originalQuality, double originalSize, double prunedQuality, double prunedSize) {
      // deterioration in quality:
      // exp: MSE : newMse < origMse (improvement)   => prefer the larger improvement
      //      MSE : newMse > origMse (deterioration) => prefer the smaller deterioration
      //      MSE : minimize: newMse / origMse
      //      R²  : newR² > origR²   (improvement)   => prefer the larger improvement
      //      R²  : newR² < origR²   (deterioration) => prefer the smaller deterioration
      //      R²  : minimize: origR² / newR²
      double qualityDeterioration;
      if (prunedQuality == originalQuality) {
        // Unchanged quality is a ratio of exactly 1. Handling it explicitly avoids 0/0 = NaN when both
        // values are zero (e.g. a perfect fit), which would otherwise exclude a loss-free pruning.
        qualityDeterioration = 1.0;
      } else if (maximization) {
        qualityDeterioration = originalQuality / prunedQuality;
      } else {
        qualityDeterioration = prunedQuality / originalQuality;
      }
      // size of the pruned tree is never larger than the size of the original tree
      // same change in quality => prefer pruning operation that removes a larger tree
      double sizeReduction = originalSize / prunedSize;
      // NOTE(review): the weight scales the coefficient of every candidate by the same factor, so it
      // does not appear to influence which candidate of a tournament wins - confirm the intended formula.
      return (qualityDeterioration * qualityGainWeight) / sizeReduction;
    }

    /// <summary>
    /// Calculates the constant that replaces <paramref name="branch"/>: the average of the values the
    /// interpreter returns for the branch on the given rows. The branch is wrapped into a temporary
    /// program-root / start-node tree for the interpreter call.
    /// </summary>
    /// <remarks>Averaging throws an InvalidOperationException if no values are produced.</remarks>
    private static double CalculateReplacementValue(SymbolicExpressionTreeNode branch, ISymbolicExpressionTreeInterpreter interpreter, Dataset ds, IEnumerable<int> rows) {
      SymbolicExpressionTreeNode start = (new StartSymbol()).CreateTreeNode();
      start.AddSubTree(branch);
      SymbolicExpressionTreeNode root = (new ProgramRootSymbol()).CreateTreeNode();
      root.AddSubTree(start);
      SymbolicExpressionTree tree = new SymbolicExpressionTree(root);
      IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(tree, ds, rows);
      return branchValues.Average();
    }

    /// <summary>
    /// Creates a constant tree node holding <paramref name="constantValue"/>.
    /// </summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.