Free cookie consent management tool by TermsFeed Policy Generator

source: branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4297

Last change on this file since 4297 was 4297, checked in by gkronber, 14 years ago

Added output parameter for validation quality to validation analyzer, added input parameter for validation quality to overfitting analyzer, and fixed bugs in pruning operator. #1142

File size: 20.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using System;
34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that prunes the symbolic expression trees of a population percentile (ranked by quality).
  /// For each selected tree, a tournament of randomly chosen branches is evaluated: each candidate branch
  /// is replaced by a constant holding the branch's mean output on sampled rows, and the candidate with
  /// the lowest weighted quality-deterioration / size-reduction score is kept.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string QualityParameterName = "Quality";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string RelativeNumberOfEvaluatedRowsParameterName = "RelativeNumberOfEvaluatedRows";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string ApplyPruningParameterName = "ApplyPruning";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ScopeTreeLookupParameter<DoubleValue> QualityParameter {
      get { return (ScopeTreeLookupParameter<DoubleValue>)Parameters[QualityParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<PercentValue> RelativeNumberOfEvaluatedRowsParameters {
      get { return (IValueLookupParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedRowsParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    public IValueLookupParameter<BoolValue> ApplyPruningParameter {
      get { return (IValueLookupParameter<BoolValue>)Parameters[ApplyPruningParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    [StorableConstructor]
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    /// <summary>
    /// Creates the operator and registers all of its parameters with their default values.
    /// </summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<BoolValue>(ApplyPruningParameterName));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
      Parameters.Add(new ValueLookupParameter<PercentValue>(RelativeNumberOfEvaluatedRowsParameterName, new PercentValue(1.0)));
    }

    /// <summary>
    /// Adds parameters that are missing in operators persisted by older versions of this class.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      if (!Parameters.ContainsKey(ApplyPruningParameterName)) {
        Parameters.Add(new ValueLookupParameter<BoolValue>(ApplyPruningParameterName));
      }
      if (!Parameters.ContainsKey(QualityParameterName)) {
        Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName));
      }
      if (!Parameters.ContainsKey(RelativeNumberOfEvaluatedRowsParameterName)) {
        Parameters.Add(new ValueLookupParameter<PercentValue>(RelativeNumberOfEvaluatedRowsParameterName, new PercentValue(1.0)));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the configured population percentile. Pruning runs only when ApplyPruning is
    /// true, the current generation is at least FirstPruningGeneration, and the number of generations
    /// since FirstPruningGeneration is a multiple of PruningFrequency.
    /// </summary>
    public override IOperation Apply() {
      bool pruningCondition =
        ApplyPruningParameter.ActualValue.Value &&
        Generation.Value >= FirstPruningGeneration.Value &&
        (Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0;
      if (pruningCondition) {
        ItemArray<SymbolicExpressionTree> trees = SymbolicExpressionTree;
        ItemArray<DoubleValue> quality = QualityParameter.ActualValue;
        int n = trees.Length;
        bool maximization = Maximization.Value;

        // Derive both ends of the percentile from absolute indices. Computing the count as
        // (int)(n * (end - start)) can drop an element through floating-point rounding
        // (e.g. n = 10, start = 0.1, end = 0.3 yields 1 instead of 2).
        int firstIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);

        // Rank the population best-first and take the requested slice. The result is materialized
        // because Prune overwrites the quality values that serve as sort keys.
        var selectedTrees = (from index in Enumerable.Range(0, n)
                             orderby maximization ? -quality[index].Value : quality[index].Value
                             select new { Tree = trees[index], Quality = quality[index] })
                            .Skip(firstIndex)
                            .Take(endIndex - firstIndex)
                            .ToList();

        foreach (var pair in selectedTrees) {
          Prune(Random, pair.Tree, pair.Quality, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value, RelativeNumberOfEvaluatedRowsParameters.ActualValue.Value,
            SymbolicExpressionTreeInterpreter, Evaluator, maximization,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place by repeatedly replacing a branch with a constant.
    /// </summary>
    /// <param name="random">Random number generator used to choose prune points.</param>
    /// <param name="tree">The tree to prune; its root is replaced by the root of the pruned tree.</param>
    /// <param name="quality">Quality of <paramref name="tree"/>; overwritten with the quality of the pruned tree.</param>
    /// <param name="iterations">Number of pruning iterations; each iteration keeps its best candidate.</param>
    /// <param name="tournamentSize">Number of random prune candidates evaluated per iteration.</param>
    /// <param name="problemData">Provides the dataset and target variable for evaluation.</param>
    /// <param name="samplesStart">First row of the evaluated partition.</param>
    /// <param name="samplesEnd">End row of the evaluated partition (passed through to the row sampler).</param>
    /// <param name="relativeNumberOfEvaluatedRows">Fraction of the partition rows that are sampled for evaluation.</param>
    /// <param name="interpreter">Interpreter used to compute branch outputs.</param>
    /// <param name="evaluator">Evaluator used to compute the quality of pruned candidates.</param>
    /// <param name="maximization">True if larger quality values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a pruned branch and of the total size reduction, relative to the tree size.</param>
    /// <param name="qualityGainWeight">Factor applied to the quality deterioration when scoring candidates.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, DoubleValue quality, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd, double relativeNumberOfEvaluatedRows,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {

      // Sample the evaluation rows once and materialize them, so that every branch-mean and quality
      // evaluation below uses exactly the same rows and the sampling is not re-run per enumeration.
      int numberOfSampledRows = (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows);
      List<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, numberOfSampledRows).ToList();

      int originalSize = tree.Size;
      // Pruning must never shrink the tree below this size.
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Copy of the tree whose result-producing branch is emptied; a candidate branch is attached to it
      // temporarily to compute the branch's output on the sampled rows.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      SymbolicExpressionTreeNode templateResultBranch = templateTree.Root.SubTrees[0];
      while (templateResultBranch.SubTrees.Count > 0) templateResultBranch.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      double currentQuality = quality.Value;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Candidate prune points: child branches within the result-producing branch that are small
          // enough and whose removal keeps the tree at or above the minimal pruned size.
          var prunePoints = (from node in clonedTree.Root.SubTrees[0].IterateNodesPostfix()
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                            .ToList();
          if (prunePoints.Count == 0) continue;

          var selectedPrunePoint = prunePoints.SelectRandom(random);

          // Mean output of the selected branch on the sampled rows (evaluated before detaching it).
          templateResultBranch.AddSubTree(selectedPrunePoint.Branch);
          double branchMean = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows).Average();
          templateResultBranch.RemoveSubTree(0);

          // Replace the branch by a constant holding its mean output.
          selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
          selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, CreateConstant(branchMean));

          double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
            lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
          double prunedSize = clonedTree.Size;

          // Relative quality deterioration w.r.t. the original quality (smaller is better):
          //   minimization (e.g. MSE): newQuality / originalQuality
          //     (< 1 is an improvement; prefer the larger improvement or the smaller deterioration)
          //   maximization (e.g. R²):  originalQuality / newQuality
          //     (< 1 is an improvement; prefer the larger improvement or the smaller deterioration)
          double qualityDeterioration = maximization ? quality.Value / prunedQuality : prunedQuality / quality.Value;
          // The pruned tree is never larger than the original tree. For the same change in quality,
          // prefer the candidate that removes more nodes: a larger originalSize / prunedSize lowers the gain.
          // NOTE(review): qualityGainWeight multiplies every candidate's gain by the same factor, so a
          // positive weight does not change which candidate is selected.
          double gain = (qualityDeterioration * qualityGainWeight) /
                        (originalSize / prunedSize);
          if (gain < bestGain) {
            bestGain = gain;
            iterationBestTree = clonedTree;
            currentQuality = prunedQuality;
          }
        }
        prunedTree = iterationBestTree;
      }

      quality.Value = currentQuality;
      tree.Root = prunedTree.Root;
    }

    /// <summary>
    /// Creates a constant tree node with the given value.
    /// </summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.