Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ParameterBinding/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 6479

Last change on this file since 6479 was 4722, checked in by swagner, 14 years ago

Merged cloning refactoring branch back into trunk (#922)

File size: 18.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that shrinks symbolic regression trees by replacing branches with constants.
  /// For every tree in a configurable slice of the population it runs a number of pruning
  /// iterations; in each iteration a tournament of randomly chosen candidate prunings is
  /// evaluated and the candidate with the best trade-off between quality change and size
  /// reduction replaces the tree.
  /// </summary>
  public sealed class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    // Keys under which the parameters are registered in the Parameters collection.
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    // Typed accessors for the parameter objects stored in the Parameters collection.
    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion

    // Shortcuts returning the ActualValue of the corresponding parameter.
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    /// <summary>Constructor used by the persistence framework when deserializing.</summary>
    [StorableConstructor]
    private SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    /// <summary>Copy constructor used by <see cref="Clone(Cloner)"/>.</summary>
    private SymbolicRegressionTournamentPruning(SymbolicRegressionTournamentPruning original, Cloner cloner) : base(original, cloner) { }

    /// <summary>
    /// Creates the operator and registers all of its parameters. Defaults: pruning ratio 0.5,
    /// tournament size 10, population slice 0.25 to 0.75, quality gain weight 1.0, one iteration,
    /// pruning in every generation starting with generation 1.
    /// </summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      // NOTE: Prune() treats SamplesEnd as an exclusive upper bound (Range(start, end - start)).
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>Creates a copy of this operator via the copy constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionTournamentPruning(this, cloner);
    }

    /// <summary>
    /// Adds parameters that are missing from instances persisted before those parameters existed,
    /// using the same descriptions as the default constructor.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the configured slice of the population when the current generation
    /// is a pruning generation, then continues with the successor operation.
    /// </summary>
    /// <remarks>
    /// Pruning happens when Generation is at least FirstPruningGeneration and the distance to
    /// FirstPruningGeneration is a multiple of PruningFrequency. A PruningFrequency of 0 makes
    /// the modulo operation throw a DivideByZeroException.
    /// </remarks>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      bool pruningCondition =
        (Generation.Value >= FirstPruningGeneration.Value) &&
        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        int n = SymbolicExpressionTree.Length;
        double percentileStart = PopulationPercentileStart.Value;
        double percentileEnd = PopulationPercentileEnd.Value;
        // Select the trees whose array position falls into the configured slice.
        // NOTE(review): this is only a quality percentile if the tree array is ordered
        // by quality; the ordering is not established in this file - confirm with the caller.
        var trees = SymbolicExpressionTree
          .Skip((int)(n * percentileStart))
          .Take((int)(n * (percentileEnd - percentileStart)));
        foreach (var tree in trees) {
          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
            SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place. In each iteration, up to
    /// <paramref name="tournamentSize"/> candidates are created by replacing one randomly chosen
    /// branch with a constant holding the mean output of that branch; the candidate with the
    /// smallest combined quality-deterioration / size-reduction score becomes the input of the
    /// next iteration. Finally the root of <paramref name="tree"/> is set to the root of the result.
    /// </summary>
    /// <param name="random">Random number generator used to pick the branch to prune.</param>
    /// <param name="tree">The tree to prune; its Root is replaced.</param>
    /// <param name="iterations">Number of successive pruning rounds.</param>
    /// <param name="tournamentSize">Number of candidate prunings compared per round.</param>
    /// <param name="problemData">Provides the dataset, the target variable and the test partition.</param>
    /// <param name="samplesStart">First row (inclusive) used for evaluation.</param>
    /// <param name="samplesEnd">End row (exclusive) used for evaluation.</param>
    /// <param name="interpreter">Interpreter used to compute the output values of a branch.</param>
    /// <param name="evaluator">Evaluator used to compute the quality of the original and the candidate trees.</param>
    /// <param name="maximization">True if larger quality values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a pruned branch relative to the tree size.</param>
    /// <param name="qualityGainWeight">Weight of the quality change relative to the size change.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Rows of [samplesStart, samplesEnd) that lie outside the test partition. Materialized once
      // because the same rows are used for every evaluation below; a deferred query would be
      // re-filtered on each enumeration.
      int[] rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < problemData.TestSamplesStart.Value || problemData.TestSamplesEnd.Value <= i)
        .ToArray();
      var dataset = problemData.Dataset;
      var targetVariable = problemData.TargetVariable.Value;

      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);

      // Lower bound for the size that remains after cutting out a branch.
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Tree for branch evaluation: a clone whose first root subtree has no children, so that a
      // single branch can be attached temporarily and interpreted on its own.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // All branches that are small enough to be pruned and whose removal does not shrink
          // the tree below the minimal size.
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);

            // Mean output of the selected branch over the evaluation rows.
            templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, dataset, rows);
            double branchMean = branchValues.Average();
            templateTree.Root.SubTrees[0].RemoveSubTree(0);

            // Replace the branch with a constant node holding that mean.
            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            // Evaluate the candidate on exactly the same rows as the original tree, so that the
            // two qualities are comparable and the test partition does not influence pruning.
            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);
            int prunedSize = clonedTree.Size;
            // deterioration in quality:
            // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement
            //      MSE : newMse > origMse (deterioration) => prefer the smaller deterioration
            //      MSE : minimize: newMse / origMse
            //      R²  : newR² > origR²   (improvement) => prefer the larger improvement
            //      R²  : newR² < origR²   (deterioration) => prefer smaller deterioration
            //      R²  : minimize: origR² / newR²
            double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
            // size of the pruned tree is always smaller than the size of the original tree
            // same change in quality => prefer pruning operation that removes a larger tree
            double sizeRatio = (double)originalSize / prunedSize;
            double gain = (qualityDeterioration * qualityGainWeight) / sizeRatio;
            // Smaller values are better; a NaN gain never wins the comparison.
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node with the given value.</summary>
    /// <param name="constantValue">The value assigned to the new node.</param>
    /// <returns>The newly created constant node.</returns>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.