source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4191

Last change on this file since 4191 was 4191, checked in by gkronber, 9 years ago

Changed validation best solution analyzer and tournament pruning operator to use the evaluator specified in the problem parameters. #1117

File size: 18.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33
34namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
35  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
36    private const string RandomParameterName = "Random";
37    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
38    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
39    private const string SamplesStartParameterName = "SamplesStart";
40    private const string SamplesEndParameterName = "SamplesEnd";
41    private const string EvaluatorParameterName = "Evaluator";
42    private const string MaximizationParameterName = "Maximization";
43    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
44    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
45    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
46    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
47    private const string TournamentSizeParameterName = "TournamentSize";
48    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
49    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
50    private const string QualityGainWeightParameterName = "QualityGainWeight";
51    private const string IterationsParameterName = "Iterations";
52    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
53    private const string PruningFrequencyParameterName = "PruningFrequency";
54    private const string GenerationParameterName = "Generations";
55    private const string ResultsParameterName = "Results";
56
57    #region parameter properties
58    public ILookupParameter<IRandom> RandomParameter {
59      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
60    }
61    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
62      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
63    }
64    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
65      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
66    }
67    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
68      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
69    }
70    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
71      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
72    }
73    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
74      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
75    }
76    public IValueLookupParameter<IntValue> SamplesStartParameter {
77      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
78    }
79    public IValueLookupParameter<IntValue> SamplesEndParameter {
80      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
81    }
82    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
83      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
84    }
85    public ILookupParameter<BoolValue> MaximizationParameter {
86      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
87    }
88    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
89      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
90    }
91    public IValueLookupParameter<IntValue> TournamentSizeParameter {
92      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
93    }
94    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
95      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
96    }
97    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
98      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
99    }
100    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
101      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
102    }
103    public IValueLookupParameter<IntValue> IterationsParameter {
104      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
105    }
106    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
107      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
108    }
109    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
110      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
111    }
112    public ILookupParameter<IntValue> GenerationParameter {
113      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
114    }
115    public ILookupParameter<ResultCollection> ResultsParameter {
116      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
117    }
118    #endregion
119    #region properties
120    public IRandom Random {
121      get { return RandomParameter.ActualValue; }
122    }
123    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
124      get { return SymbolicExpressionTreeParameter.ActualValue; }
125    }
126    public DataAnalysisProblemData DataAnalysisProblemData {
127      get { return DataAnalysisProblemDataParameter.ActualValue; }
128    }
129    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
130      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
131    }
132    public DoubleValue UpperEstimationLimit {
133      get { return UpperEstimationLimitParameter.ActualValue; }
134    }
135    public DoubleValue LowerEstimationLimit {
136      get { return LowerEstimationLimitParameter.ActualValue; }
137    }
138    public IntValue SamplesStart {
139      get { return SamplesStartParameter.ActualValue; }
140    }
141    public IntValue SamplesEnd {
142      get { return SamplesEndParameter.ActualValue; }
143    }
144    public ISymbolicRegressionEvaluator Evaluator {
145      get { return EvaluatorParameter.ActualValue; }
146    }
147    public BoolValue Maximization {
148      get { return MaximizationParameter.ActualValue; }
149    }
150    public DoubleValue MaxPruningRatio {
151      get { return MaxPruningRatioParameter.ActualValue; }
152    }
153    public IntValue TournamentSize {
154      get { return TournamentSizeParameter.ActualValue; }
155    }
156    public DoubleValue PopulationPercentileStart {
157      get { return PopulationPercentileStartParameter.ActualValue; }
158    }
159    public DoubleValue PopulationPercentileEnd {
160      get { return PopulationPercentileEndParameter.ActualValue; }
161    }
162    public DoubleValue QualityGainWeight {
163      get { return QualityGainWeightParameter.ActualValue; }
164    }
165    public IntValue Iterations {
166      get { return IterationsParameter.ActualValue; }
167    }
168    public IntValue PruningFrequency {
169      get { return PruningFrequencyParameter.ActualValue; }
170    }
171    public IntValue FirstPruningGeneration {
172      get { return FirstPruningGenerationParameter.ActualValue; }
173    }
174    public IntValue Generation {
175      get { return GenerationParameter.ActualValue; }
176    }
177    #endregion
178    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }
179    public SymbolicRegressionTournamentPruning()
180      : base() {
181      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
182      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
183      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
184      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
185      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
186      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
187      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
188      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
189      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
190      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
191      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
192      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
193      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
194      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
195      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
196      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
197      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
198      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
199      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
200      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
201    }
202
203    [StorableHook(HookType.AfterDeserialization)]
204    private void AfterDeserialization() {
205      #region compatibility remove before releasing 3.3.1
206      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
207        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
208      }
209      if (!Parameters.ContainsKey(MaximizationParameterName)) {
210        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
211      }
212      #endregion
213    }
214
215    public override IOperation Apply() {
216      bool pruningCondition =
217        (Generation.Value >= FirstPruningGeneration.Value) &&
218        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
219      if (pruningCondition) {
220        int n = SymbolicExpressionTree.Length;
221        double percentileStart = PopulationPercentileStart.Value;
222        double percentileEnd = PopulationPercentileEnd.Value;
223        // for each tree in the given percentile
224        var trees = SymbolicExpressionTree
225          .Skip((int)(n * percentileStart))
226          .Take((int)(n * (percentileEnd - percentileStart)));
227        foreach (var tree in trees) {
228          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
229            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
230            SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
231            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
232            MaxPruningRatio.Value, QualityGainWeight.Value);
233        }
234      }
235      return base.Apply();
236    }
237
238    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
239      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
240      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
241      double lowerEstimationLimit, double upperEstimationLimit,
242      double maxPruningRatio, double qualityGainWeight) {
243      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
244      int originalSize = tree.Size;
245      double originalQuality = evaluator.Evaluate(interpreter, tree,
246        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
247
248      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));
249
250      // tree for branch evaluation
251      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
252      while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);
253
254      SymbolicExpressionTree prunedTree = tree;
255      for (int iteration = 0; iteration < iterations; iteration++) {
256        SymbolicExpressionTree iterationBestTree = prunedTree;
257        double bestGain = double.PositiveInfinity;
258        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);
259
260        for (int i = 0; i < tournamentSize; i++) {
261          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
262          int clonedTreeSize = clonedTree.Size;
263          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
264                             from subTree in node.SubTrees
265                             let subTreeSize = subTree.GetSize()
266                             where subTreeSize <= maxPrunedBranchSize
267                             where clonedTreeSize - subTreeSize >= minPrunedSize
268                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
269                 .ToList();
270          if (prunePoints.Count > 0) {
271            var selectedPrunePoint = prunePoints.SelectRandom(random);
272            templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
273            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
274            double branchMean = branchValues.Average();
275            templateTree.Root.SubTrees[0].RemoveSubTree(0);
276
277            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
278            var constNode = CreateConstant(branchMean);
279            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);
280
281            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
282        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, Enumerable.Range(samplesStart, samplesEnd - samplesStart));
283            double prunedSize = clonedTree.Size;
284            // deteriation in quality:
285            // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement
286            //      MSE : newMse > origMse (deteriation) => prefer the smaller deteriation
287            //      MSE : minimize: newMse / origMse
288            //      R²  : newR² > origR²   (improvment) => prefer the larger improvment
289            //      R²  : newR² < origR²   (deteriation) => prefer smaller deteriation
290            //      R²  : minimize: origR² / newR²
291            double qualityDeteriation = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
292            // size of the pruned tree is always smaller than the size of the original tree
293            // same change in quality => prefer pruning operation that removes a larger tree
294            double gain = (qualityDeteriation * qualityGainWeight) /
295                           (originalSize / prunedSize);
296            if (gain < bestGain) {
297              bestGain = gain;
298              iterationBestTree = clonedTree;
299            }
300          }
301        }
302        prunedTree = iterationBestTree;
303      }
304      tree.Root = prunedTree.Root;
305    }
306
307    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
308      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
309      node.Value = constantValue;
310      return node;
311    }
312  }
313}
Note: See TracBrowser for help on using the repository browser.