Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4328

Last change on this file since 4328 was 4328, checked in by gkronber, 14 years ago

Overhauled pruning operator. #1142

File size: 22.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using System;
34using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
35
36namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
37  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
// Keys under which this operator registers its parameters in the Parameters collection.
// The string values are the lookup keys used by the parameter properties below.

// random source, individuals and problem data
private const string RandomParameterName = "Random";
private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";

// evaluation of pruned trees
private const string EvaluatorParameterName = "Evaluator";
private const string MaximizationParameterName = "Maximization";
private const string SamplesStartParameterName = "SamplesStart";
private const string SamplesEndParameterName = "SamplesEnd";
private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";

// pruning configuration
private const string MaxPruningRatioParameterName = "MaxPruningRatio";
private const string TournamentSizeParameterName = "TournamentSize";
private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
private const string QualityGainWeightParameterName = "QualityGainWeight";
private const string IterationsParameterName = "Iterations";

// pruning schedule and result reporting
private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
private const string PruningFrequencyParameterName = "PruningFrequency";
// note: the constant is named in the singular but the key is the plural "Generations"
private const string GenerationParameterName = "Generations";
private const string ResultsParameterName = "Results";
58
#region parameter properties
// Each property fetches the parameter registered under the given key and casts it
// to the parameter type it was created with in the constructor.

/// <summary>Lookup parameter "Random": the random number generator.</summary>
public ILookupParameter<IRandom> RandomParameter {
  get {
    var parameter = Parameters[RandomParameterName];
    return (ILookupParameter<IRandom>)parameter;
  }
}
/// <summary>Scope-tree lookup parameter for the symbolic expression trees to prune.</summary>
public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
  get {
    var parameter = Parameters[SymbolicExpressionTreeParameterName];
    return (ScopeTreeLookupParameter<SymbolicExpressionTree>)parameter;
  }
}
/// <summary>Scope-tree lookup parameter "Quality"; Apply reads one quality value per tree from it.</summary>
public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
  get {
    var parameter = Parameters["Quality"];
    return (ScopeTreeLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Lookup parameter for the problem data used for branch impact evaluation.</summary>
public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
  get {
    var parameter = Parameters[DataAnalysisProblemDataParameterName];
    return (ILookupParameter<DataAnalysisProblemData>)parameter;
  }
}
/// <summary>Lookup parameter for the interpreter used for node impact evaluation.</summary>
public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
  get {
    var parameter = Parameters[SymbolicExpressionTreeInterpreterParameterName];
    return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)parameter;
  }
}
/// <summary>Value-lookup parameter for the upper estimation limit used for evaluation.</summary>
public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
  get {
    var parameter = Parameters[UpperEstimationLimitParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the lower estimation limit used for evaluation.</summary>
public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
  get {
    var parameter = Parameters[LowerEstimationLimitParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the first row index of the evaluation partition.</summary>
public IValueLookupParameter<IntValue> SamplesStartParameter {
  get {
    var parameter = Parameters[SamplesStartParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the last row index of the evaluation partition.</summary>
public IValueLookupParameter<IntValue> SamplesEndParameter {
  get {
    var parameter = Parameters[SamplesEndParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Value-lookup parameter "RelativeNumberOfEvaluatedRows"; Apply passes its value to Prune as the relative number of rows to sample.</summary>
public IValueLookupParameter<PercentValue> RelativeNumberOfEvaluatedRowsParameters {
  get {
    var parameter = Parameters["RelativeNumberOfEvaluatedRows"];
    return (IValueLookupParameter<PercentValue>)parameter;
  }
}
/// <summary>Lookup parameter for the evaluator that determines which branches are not relevant.</summary>
public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
  get {
    var parameter = Parameters[EvaluatorParameterName];
    return (ILookupParameter<ISymbolicRegressionEvaluator>)parameter;
  }
}
/// <summary>Lookup parameter for the direction of optimization.</summary>
public ILookupParameter<BoolValue> MaximizationParameter {
  get {
    var parameter = Parameters[MaximizationParameterName];
    return (ILookupParameter<BoolValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the maximal relative size of the pruned branch.</summary>
public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
  get {
    var parameter = Parameters[MaxPruningRatioParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the number of branches compared per pruning tournament.</summary>
public IValueLookupParameter<IntValue> TournamentSizeParameter {
  get {
    var parameter = Parameters[TournamentSizeParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the start of the population percentile considered for pruning.</summary>
public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
  get {
    var parameter = Parameters[PopulationPercentileStartParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the end of the population percentile considered for pruning.</summary>
public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
  get {
    var parameter = Parameters[PopulationPercentileEndParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the weight of the quality gain relative to the size gain.</summary>
public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
  get {
    var parameter = Parameters[QualityGainWeightParameterName];
    return (IValueLookupParameter<DoubleValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the number of pruning iterations applied to each tree.</summary>
public IValueLookupParameter<IntValue> IterationsParameter {
  get {
    var parameter = Parameters[IterationsParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the first generation in which pruning is applied.</summary>
public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
  get {
    var parameter = Parameters[FirstPruningGenerationParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Value-lookup parameter for the pruning frequency (1: every generation, 2: every second generation...).</summary>
public IValueLookupParameter<IntValue> PruningFrequencyParameter {
  get {
    var parameter = Parameters[PruningFrequencyParameterName];
    return (IValueLookupParameter<IntValue>)parameter;
  }
}
/// <summary>Lookup parameter for the current generation.</summary>
public ILookupParameter<IntValue> GenerationParameter {
  get {
    var parameter = Parameters[GenerationParameterName];
    return (ILookupParameter<IntValue>)parameter;
  }
}
/// <summary>Lookup parameter for the results collection.</summary>
public ILookupParameter<ResultCollection> ResultsParameter {
  get {
    var parameter = Parameters[ResultsParameterName];
    return (ILookupParameter<ResultCollection>)parameter;
  }
}
/// <summary>Value-lookup parameter "ApplyPruning"; Apply only prunes when its actual value is true.</summary>
public IValueLookupParameter<BoolValue> ApplyPruningParameter {
  get {
    var parameter = Parameters["ApplyPruning"];
    return (IValueLookupParameter<BoolValue>)parameter;
  }
}
#endregion
#region properties
// Convenience accessors: each one returns the ActualValue of the matching parameter property.

/// <summary>Actual value of <see cref="RandomParameter"/>.</summary>
public IRandom Random {
  get {
    return RandomParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="SymbolicExpressionTreeParameter"/> (one tree per array element).</summary>
public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
  get {
    return SymbolicExpressionTreeParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="DataAnalysisProblemDataParameter"/>.</summary>
public DataAnalysisProblemData DataAnalysisProblemData {
  get {
    return DataAnalysisProblemDataParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="SymbolicExpressionTreeInterpreterParameter"/>.</summary>
public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
  get {
    return SymbolicExpressionTreeInterpreterParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="UpperEstimationLimitParameter"/>.</summary>
public DoubleValue UpperEstimationLimit {
  get {
    return UpperEstimationLimitParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="LowerEstimationLimitParameter"/>.</summary>
public DoubleValue LowerEstimationLimit {
  get {
    return LowerEstimationLimitParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="SamplesStartParameter"/>.</summary>
public IntValue SamplesStart {
  get {
    return SamplesStartParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="SamplesEndParameter"/>.</summary>
public IntValue SamplesEnd {
  get {
    return SamplesEndParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="EvaluatorParameter"/>.</summary>
public ISymbolicRegressionEvaluator Evaluator {
  get {
    return EvaluatorParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="MaximizationParameter"/>.</summary>
public BoolValue Maximization {
  get {
    return MaximizationParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="MaxPruningRatioParameter"/>.</summary>
public DoubleValue MaxPruningRatio {
  get {
    return MaxPruningRatioParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="TournamentSizeParameter"/>.</summary>
public IntValue TournamentSize {
  get {
    return TournamentSizeParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="PopulationPercentileStartParameter"/>.</summary>
public DoubleValue PopulationPercentileStart {
  get {
    return PopulationPercentileStartParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="PopulationPercentileEndParameter"/>.</summary>
public DoubleValue PopulationPercentileEnd {
  get {
    return PopulationPercentileEndParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="QualityGainWeightParameter"/>.</summary>
public DoubleValue QualityGainWeight {
  get {
    return QualityGainWeightParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="IterationsParameter"/>.</summary>
public IntValue Iterations {
  get {
    return IterationsParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="PruningFrequencyParameter"/>.</summary>
public IntValue PruningFrequency {
  get {
    return PruningFrequencyParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="FirstPruningGenerationParameter"/>.</summary>
public IntValue FirstPruningGeneration {
  get {
    return FirstPruningGenerationParameter.ActualValue;
  }
}
/// <summary>Actual value of <see cref="GenerationParameter"/>.</summary>
public IntValue Generation {
  get {
    return GenerationParameter.ActualValue;
  }
}
#endregion
/// <summary>
/// Constructor used by the persistence framework; forwards the flag to the base class
/// and registers no parameters itself.
/// </summary>
[StorableConstructor]
protected SymbolicRegressionTournamentPruning(bool deserializing)
  : base(deserializing) {
}

/// <summary>
/// Creates the operator and registers all of its parameters, with default values where given.
/// The registration order is kept exactly as it was.
/// </summary>
public SymbolicRegressionTournamentPruning() {
  // random source, individuals and their qualities
  Parameters.Add(new LookupParameter<IRandom>(
    RandomParameterName,
    "A random number generator."));
  Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(
    SymbolicExpressionTreeParameterName,
    "The symbolic expression trees to prune."));
  Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(
    "Quality"));

  // data and services needed to evaluate the impact of a branch
  Parameters.Add(new LookupParameter<DataAnalysisProblemData>(
    DataAnalysisProblemDataParameterName,
    "The data analysis problem data to use for branch impact evaluation."));
  Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(
    SymbolicExpressionTreeInterpreterParameterName,
    "The interpreter to use for node impact evaluation"));
  Parameters.Add(new ValueLookupParameter<IntValue>(
    SamplesStartParameterName,
    "The first row index of the dataset partition to use for branch impact evaluation."));
  Parameters.Add(new ValueLookupParameter<IntValue>(
    SamplesEndParameterName,
    "The last row index of the dataset partition to use for branch impact evaluation."));
  Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(
    EvaluatorParameterName,
    "The evaluator that should be used to determine which branches are not relevant."));
  Parameters.Add(new LookupParameter<BoolValue>(
    MaximizationParameterName,
    "The direction of optimization."));

  // pruning switch and tournament configuration
  Parameters.Add(new ValueLookupParameter<BoolValue>(
    "ApplyPruning"));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    MaxPruningRatioParameterName,
    "The maximal relative size of the pruned branch.",
    new DoubleValue(0.5)));
  Parameters.Add(new ValueLookupParameter<IntValue>(
    TournamentSizeParameterName,
    "The number of branches to compare for pruning",
    new IntValue(10)));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    PopulationPercentileStartParameterName,
    "The start of the population percentile to consider for pruning.",
    new DoubleValue(0.25)));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    PopulationPercentileEndParameterName,
    "The end of the population percentile to consider for pruning.",
    new DoubleValue(0.75)));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    QualityGainWeightParameterName,
    "The weight of the quality gain relative to the size gain.",
    new DoubleValue(1.0)));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    UpperEstimationLimitParameterName,
    "The upper estimation limit to use for evaluation."));
  Parameters.Add(new ValueLookupParameter<DoubleValue>(
    LowerEstimationLimitParameterName,
    "The lower estimation limit to use for evaluation."));
  Parameters.Add(new ValueLookupParameter<IntValue>(
    IterationsParameterName,
    "The number of pruning iterations to apply for each tree.",
    new IntValue(1)));

  // pruning schedule, results and row sampling
  Parameters.Add(new ValueLookupParameter<IntValue>(
    FirstPruningGenerationParameterName,
    "The first generation when pruning should be applied.",
    new IntValue(1)));
  Parameters.Add(new ValueLookupParameter<IntValue>(
    PruningFrequencyParameterName,
    "The frequency of pruning operations (1: every generation, 2: every second generation...)",
    new IntValue(1)));
  Parameters.Add(new LookupParameter<IntValue>(
    GenerationParameterName,
    "The current generation."));
  Parameters.Add(new LookupParameter<ResultCollection>(
    ResultsParameterName,
    "The results collection."));
  Parameters.Add(new ValueLookupParameter<PercentValue>(
    "RelativeNumberOfEvaluatedRows",
    new PercentValue(1.0)));
}
217
/// <summary>
/// Runs after deserialization and adds parameters that are not yet present in the
/// Parameters collection, so that instances stored without them get the same
/// parameter set as a freshly constructed operator.
/// </summary>
[StorableHook(HookType.AfterDeserialization)]
private void AfterDeserialization() {
  #region compatibility remove before releasing 3.3.1
  bool evaluatorMissing = !Parameters.ContainsKey(EvaluatorParameterName);
  if (evaluatorMissing) {
    Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(
      EvaluatorParameterName,
      "The evaluator which should be used to evaluate the solution on the validation set."));
  }

  bool maximizationMissing = !Parameters.ContainsKey(MaximizationParameterName);
  if (maximizationMissing) {
    Parameters.Add(new LookupParameter<BoolValue>(
      MaximizationParameterName,
      "The direction of optimization."));
  }

  bool applyPruningMissing = !Parameters.ContainsKey("ApplyPruning");
  if (applyPruningMissing) {
    Parameters.Add(new ValueLookupParameter<BoolValue>("ApplyPruning"));
  }

  bool qualityMissing = !Parameters.ContainsKey("Quality");
  if (qualityMissing) {
    Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>("Quality"));
  }

  bool relativeRowsMissing = !Parameters.ContainsKey("RelativeNumberOfEvaluatedRows");
  if (relativeRowsMissing) {
    Parameters.Add(new ValueLookupParameter<PercentValue>(
      "RelativeNumberOfEvaluatedRows",
      new PercentValue(1.0)));
  }
  #endregion
}
239
/// <summary>
/// Prunes the trees of the configured population percentile when pruning is due.
/// </summary>
/// <remarks>
/// Pruning is due when ApplyPruning is true, the current generation has reached
/// FirstPruningGeneration, and the number of generations since then is a multiple of
/// PruningFrequency. A PruningFrequency of zero makes the modulo throw a
/// DivideByZeroException. The trees are ordered best-first (descending quality for
/// maximization, ascending otherwise) and the slice between PopulationPercentileStart
/// and PopulationPercentileEnd is passed tree by tree to <see cref="Prune"/>, which
/// modifies the tree and its quality value in place.
/// </remarks>
/// <returns>The result of the base class' Apply.</returns>
public override IOperation Apply() {
  bool pruningCondition =
    (ApplyPruningParameter.ActualValue.Value) &&
    (Generation.Value >= FirstPruningGeneration.Value) &&
    ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
  if (pruningCondition) {
    ItemArray<SymbolicExpressionTree> trees = SymbolicExpressionTree;
    ItemArray<DoubleValue> quality = QualityParameter.ActualValue;
    int n = trees.Length;
    bool maximization = Maximization.Value;

    // Truncate both percentile bounds to indices and derive the slice length from the
    // indices. Computing (int)(n * (end - start)) instead loses an element whenever the
    // floating-point subtraction lands just below the exact value
    // (e.g. n = 10, start = 0.1, end = 0.3 gives 1.9999999999999998 -> 1 instead of 2).
    int firstIndex = (int)(n * PopulationPercentileStart.Value);
    int endIndex = (int)(n * PopulationPercentileEnd.Value);

    // materialize the selection so that the ordering is fixed before Prune starts
    // overwriting quality values
    var selectedTrees = (from index in Enumerable.Range(0, n)
                         orderby maximization ? -quality[index].Value : quality[index].Value
                         select new { Tree = trees[index], Quality = quality[index] })
                        .Skip(firstIndex)
                        .Take(endIndex - firstIndex)
                        .ToList();

    foreach (var pair in selectedTrees) {
      Prune(Random, pair.Tree, pair.Quality, Iterations.Value, TournamentSize.Value,
        DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value, RelativeNumberOfEvaluatedRowsParameters.ActualValue.Value,
        SymbolicExpressionTreeInterpreter, Evaluator, maximization,
        LowerEstimationLimit.Value, UpperEstimationLimit.Value,
        MaxPruningRatio.Value, QualityGainWeight.Value);
    }
  }
  return base.Apply();
}
268
/// <summary>
/// Repeatedly prunes <paramref name="tree"/> in place by running up to
/// <paramref name="iterations"/> pruning tournaments.
/// </summary>
/// <param name="random">Random number generator used to pick tournament candidates.</param>
/// <param name="tree">The tree to prune; its root is replaced when a branch is pruned.</param>
/// <param name="quality">Quality of the tree; overwritten with the quality of the pruned tree.</param>
/// <param name="iterations">Maximal number of pruning tournaments.</param>
/// <param name="tournamentSize">Number of candidate branches compared per tournament.</param>
/// <param name="problemData">Supplies the dataset and the target variable name.</param>
/// <param name="samplesStart">Start of the row range from which evaluation rows are sampled.</param>
/// <param name="samplesEnd">End of the row range from which evaluation rows are sampled.</param>
/// <param name="relativeNumberOfEvaluatedRows">Fraction of the row range to sample (rounded up to whole rows).</param>
/// <param name="interpreter">Interpreter used to evaluate trees and branches.</param>
/// <param name="evaluator">Evaluator used to compute the quality of a pruned tree.</param>
/// <param name="maximization">True if larger quality values are better.</param>
/// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
/// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
/// <param name="maxPruningRatio">The tree is never pruned below (1 - maxPruningRatio) of its initial size.</param>
/// <param name="qualityGainWeight">Weight of the quality change, forwarded to the tournament.</param>
public static void Prune(IRandom random, SymbolicExpressionTree tree, DoubleValue quality, int iterations, int tournamentSize,
  DataAnalysisProblemData problemData, int samplesStart, int samplesEnd, double relativeNumberOfEvaluatedRows,
  ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
  double lowerEstimationLimit, double upperEstimationLimit,
  double maxPruningRatio, double qualityGainWeight) {

  // min size of the resulting pruned tree, relative to the size before any pruning
  int minPrunedSize = (int)(tree.Size * (1 - maxPruningRatio));

  // Use the same subset of rows for all iterations and for all pruning tournaments.
  // The sample is materialized once: the sampler returns an IEnumerable that would
  // otherwise be re-enumerated for every candidate evaluation.
  int numberOfRows = (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows);
  List<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, numberOfRows).ToList();
  if (rows.Count == 0) {
    // without rows there is nothing to judge a pruning against (the replacement value
    // is an average over the rows), so the tree is left untouched
    return;
  }

  for (int iteration = 0; iteration < iterations; iteration++) {
    // maximally prune a branch such that the resulting tree size is not smaller than (1-maxPruningRatio) of the original tree
    int maxPrunedBranchSize = tree.Size - minPrunedSize;
    if (maxPrunedBranchSize <= 0) {
      // pruning never increases the tree size, so the budget cannot become positive again
      break;
    }
    PruneTournament(tree, quality, random, tournamentSize, maxPrunedBranchSize, maximization, qualityGainWeight, evaluator, interpreter, problemData.Dataset, problemData.TargetVariable.Value, rows, lowerEstimationLimit, upperEstimationLimit);
  }
}
291
/// <summary>
/// Immutable description of a candidate pruning location: the branch, the node it hangs
/// on, and the position of the branch among that node's sub-trees.
/// </summary>
private class PruningPoint {
  private readonly SymbolicExpressionTreeNode parent;
  private readonly SymbolicExpressionTreeNode branch;
  private readonly int subTreeIndex;

  /// <summary>The node whose sub-tree is the pruning candidate.</summary>
  public SymbolicExpressionTreeNode Parent {
    get { return parent; }
  }
  /// <summary>The candidate branch itself.</summary>
  public SymbolicExpressionTreeNode Branch {
    get { return branch; }
  }
  /// <summary>Index of the branch within the parent's sub-trees.</summary>
  public int SubTreeIndex {
    get { return subTreeIndex; }
  }

  public PruningPoint(SymbolicExpressionTreeNode parent, SymbolicExpressionTreeNode branch, int index) {
    this.parent = parent;
    this.branch = branch;
    this.subTreeIndex = index;
  }
}
302
/// <summary>
/// Runs one pruning tournament on <paramref name="tree"/>: a group of candidate branches
/// is tried, each temporarily replaced by a constant holding the branch's average output,
/// and the replacement with the smallest pruning coefficient is applied to the tree.
/// </summary>
/// <remarks>
/// Candidates are all non-constant sub-trees below the start node whose size does not
/// exceed <paramref name="maxPrunedBranchSize"/>. If there are more candidates than
/// <paramref name="tournamentSize"/>, the group is drawn randomly (the same candidate may
/// be drawn more than once); otherwise all candidates compete. When a winner exists,
/// the root of <paramref name="tree"/> is replaced and <paramref name="quality"/> is
/// overwritten with the quality of the pruned tree. A candidate whose coefficient is NaN
/// or not smaller than positive infinity never wins, so the tree may stay unchanged.
/// NOTE(review): the coefficient compares <paramref name="quality"/> as passed in with
/// qualities measured on <paramref name="rows"/>; this presumes the incoming quality was
/// computed on comparable rows - confirm against the caller.
/// </remarks>
private static void PruneTournament(SymbolicExpressionTree tree, DoubleValue quality, IRandom random, int tournamentSize,
  int maxPrunedBranchSize, bool maximization, double qualityGainWeight, ISymbolicRegressionEvaluator evaluator, ISymbolicExpressionTreeInterpreter interpreter,
  Dataset ds, string targetVariable, IEnumerable<int> rows, double lowerEstimationLimit, double upperEstimationLimit) {
  // all trial replacements happen on a private clone; the caller's tree is only
  // touched once the winner is known
  SymbolicExpressionTree pruningEvaluationTree = (SymbolicExpressionTree)tree.Clone();

  // collect candidates; the loop index is the sub-tree position, so no IndexOf search is needed
  var prunePoints = new List<PruningPoint>();
  foreach (var node in pruningEvaluationTree.Root.SubTrees[0].IterateNodesPostfix()) {
    for (int subTreeIndex = 0; subTreeIndex < node.SubTrees.Count; subTreeIndex++) {
      var subTree = node.SubTrees[subTreeIndex];
      if (subTree.GetSize() > maxPrunedBranchSize) {
        continue;
      }
      if (subTree.Symbol is Constant) {
        // replacing a constant by a constant gains nothing
        continue;
      }
      prunePoints.Add(new PruningPoint(node, subTree, subTreeIndex));
    }
  }
  if (prunePoints.Count == 0) {
    return;
  }

  List<PruningPoint> tournamentGroup;
  if (prunePoints.Count > tournamentSize) {
    tournamentGroup = new List<PruningPoint>();
    for (int i = 0; i < tournamentSize; i++) {
      tournamentGroup.Add(prunePoints.SelectRandom(random));
    }
  } else {
    tournamentGroup = prunePoints;
  }

  double originalQuality = quality.Value;
  double originalSize = tree.Size;

  // The winner is remembered and applied once after the tournament, instead of
  // cloning the whole evaluation tree every time a better candidate shows up.
  double bestCoeff = double.PositiveInfinity;
  PruningPoint bestPoint = null;
  double bestReplacementValue = 0.0;
  double bestQuality = 0.0;

  foreach (PruningPoint prunePoint in tournamentGroup) {
    double replacementValue = CalculateReplacementValue(prunePoint.Branch, interpreter, ds, rows);

    // temporarily replace the branch with a constant
    prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
    prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, CreateConstant(replacementValue));

    // evaluate the pruned tree
    double prunedQuality = evaluator.Evaluate(interpreter, pruningEvaluationTree,
      lowerEstimationLimit, upperEstimationLimit, ds, targetVariable, rows);

    // restore the evaluation tree for the next candidate
    prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
    prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, prunePoint.Branch);

    // the branch is replaced by a single constant node, hence the + 1
    double prunedSize = originalSize - prunePoint.Branch.GetSize() + 1;

    double coeff = CalculatePruningCoefficient(maximization, qualityGainWeight, originalQuality, originalSize, prunedQuality, prunedSize);
    if (coeff < bestCoeff) {
      bestCoeff = coeff;
      bestPoint = prunePoint;
      bestReplacementValue = replacementValue;
      bestQuality = prunedQuality;
    }
  }

  if (bestPoint != null) {
    // apply the winning replacement to the clone and hand the clone's root to the original tree
    bestPoint.Parent.RemoveSubTree(bestPoint.SubTreeIndex);
    bestPoint.Parent.InsertSubTree(bestPoint.SubTreeIndex, CreateConstant(bestReplacementValue));
    tree.Root = pruningEvaluationTree.Root;
    quality.Value = bestQuality;
  }
}
359
/// <summary>
/// Calculates the coefficient by which pruning candidates are ranked; smaller is better.
/// </summary>
/// <remarks>
/// The coefficient is (quality deterioration ^ qualityGainWeight) / (size gain).
/// The weight is applied as an exponent because a multiplicative weight scales every
/// candidate's coefficient by the same factor and therefore cannot change which candidate
/// wins. With the default weight 1.0 the value equals
/// (quality deterioration * 1.0) / (size gain), i.e. the plain ratio. A weight of 0 ranks
/// candidates by size gain only; weights above 1 penalize quality loss more strongly.
/// </remarks>
/// <param name="maximization">True if larger quality values are better (e.g. R²), false otherwise (e.g. MSE).</param>
/// <param name="qualityGainWeight">Exponent applied to the quality deterioration ratio.</param>
/// <param name="originalQuality">Quality of the tree before pruning.</param>
/// <param name="originalSize">Size of the tree before pruning.</param>
/// <param name="prunedQuality">Quality of the tree after pruning.</param>
/// <param name="prunedSize">Size of the tree after pruning.</param>
/// <returns>The pruning coefficient; NaN or infinity are possible when a quality value is zero.</returns>
private static double CalculatePruningCoefficient(bool maximization, double qualityGainWeight, double originalQuality, double originalSize, double prunedQuality, double prunedSize) {
  // deterioration in quality:
  // exp: MSE : newMse < origMse (improvement)   => prefer the larger improvement
  //      MSE : newMse > origMse (deterioration) => prefer the smaller deterioration
  //      MSE : minimize: newMse / origMse
  //      R²  : newR² > origR²   (improvement)   => prefer the larger improvement
  //      R²  : newR² < origR²   (deterioration) => prefer the smaller deterioration
  //      R²  : minimize: origR² / newR²
  double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
  // size of the pruned tree is always smaller than the size of the original tree
  // same change in quality => prefer pruning operation that removes a larger tree
  double sizeGain = originalSize / prunedSize;
  return Math.Pow(qualityDeterioration, qualityGainWeight) / sizeGain;
}
373
/// <summary>
/// Computes the constant that stands in for <paramref name="branch"/> when it is pruned:
/// the average of the branch's outputs over <paramref name="rows"/>.
/// </summary>
/// <remarks>
/// The branch is hung below a fresh start node and program root so that the interpreter
/// can evaluate it as a tree of its own. Averaging fails if the interpreter yields no values.
/// </remarks>
private static double CalculateReplacementValue(SymbolicExpressionTreeNode branch, ISymbolicExpressionTreeInterpreter interpreter, Dataset ds, IEnumerable<int> rows) {
  // build: program root -> start -> branch
  SymbolicExpressionTreeNode startNode = new StartSymbol().CreateTreeNode();
  startNode.AddSubTree(branch);
  SymbolicExpressionTreeNode rootNode = new ProgramRootSymbol().CreateTreeNode();
  rootNode.AddSubTree(startNode);

  var branchTree = new SymbolicExpressionTree(rootNode);
  IEnumerable<double> outputs = interpreter.GetSymbolicExpressionTreeValues(branchTree, ds, rows);
  double mean = outputs.Average();
  return mean;
}
383
/// <summary>
/// Creates a constant tree node whose value is <paramref name="constantValue"/>.
/// </summary>
private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
  Constant constantSymbol = new Constant();
  ConstantTreeNode constantNode = (ConstantTreeNode)constantSymbol.CreateTreeNode();
  constantNode.Value = constantValue;
  return constantNode;
}
389  }
390}
Note: See TracBrowser for help on using the repository browser.