Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 6996

Last change on this file since 6996 was 6739, checked in by mkommend, 13 years ago

#1579: Re-added the SymbolicRegressionTournamentPruning analyzer to Problems.DataAnalysis.Regression-3.3 for backwards-compatibility reasons.

File size: 18.5 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that simplifies symbolic regression trees by replacing branches with constants.
  /// In the generations selected by <c>FirstPruningGeneration</c> and <c>PruningFrequency</c>, every tree in the
  /// configured slice of the tree array is passed to <see cref="Prune"/>, which runs a tournament over
  /// randomly chosen branches and keeps the replacement with the best quality/size trade-off.
  /// </summary>
  /// <remarks>
  /// NOTE(review): this class declares a deserialization constructor and a StorableHook, but no
  /// [StorableClass] / [StorableConstructor] attributes are visible here - confirm this is intentional.
  /// </remarks>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    // Keys under which the parameters are registered in the Parameters collection.
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    // Typed accessors for the parameters registered in the default constructor.
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    // Shortcuts that return the ActualValue of the corresponding parameter.
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    /// <summary>Constructor that only forwards the deserialization flag to the base class.</summary>
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    /// <summary>Creates the analyzer and registers all of its parameters together with their default values.</summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>Cloning constructor; all state is copied by the base class.</summary>
    protected SymbolicRegressionTournamentPruning(SymbolicRegressionTournamentPruning original, Cloner cloner)
      : base(original, cloner) {
    }

    /// <summary>Creates a deep copy of this analyzer via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionTournamentPruning(this, cloner);
    }

    /// <summary>
    /// Adds the Evaluator and Maximization parameters when a deserialized instance does not contain them.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the configured slice of the tree array when the current generation is a pruning generation.
    /// </summary>
    /// <remarks>
    /// Pruning runs when Generation &gt;= FirstPruningGeneration and the distance between the two is a multiple
    /// of PruningFrequency. A PruningFrequency of 0 makes the modulo operation throw a DivideByZeroException.
    /// </remarks>
    /// <returns>The operation returned by the base implementation.</returns>
    public override IOperation Apply() {
      bool pruningCondition =
        (Generation.Value >= FirstPruningGeneration.Value) &&
        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        int n = SymbolicExpressionTree.Length;
        double percentileStart = PopulationPercentileStart.Value;
        double percentileEnd = PopulationPercentileEnd.Value;
        // for each tree in the given percentile (slice boundaries are truncated to whole indices)
        var trees = SymbolicExpressionTree
          .Skip((int)(n * percentileStart))
          .Take((int)(n * (percentileEnd - percentileStart)));
        foreach (var tree in trees) {
          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
            SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place. In each iteration up to <paramref name="tournamentSize"/> candidates
    /// are created by replacing one randomly selected branch with a constant holding the mean output of that branch;
    /// the candidate with the smallest gain value (quality deterioration relative to size reduction) becomes the
    /// starting point of the next iteration. Finally the root of <paramref name="tree"/> is set to the root of the result.
    /// </summary>
    /// <param name="random">Random number generator used to select the branch to replace.</param>
    /// <param name="tree">The tree to prune; its Root is replaced by the root of the pruned tree.</param>
    /// <param name="iterations">Number of consecutive pruning rounds.</param>
    /// <param name="tournamentSize">Number of candidate replacements compared per round.</param>
    /// <param name="problemData">Supplies the dataset, the target variable and the test partition boundaries.</param>
    /// <param name="samplesStart">First row index considered for evaluation.</param>
    /// <param name="samplesEnd">Exclusive end of the row range considered for evaluation.</param>
    /// <param name="interpreter">Interpreter used to compute branch outputs and passed to the evaluator.</param>
    /// <param name="evaluator">Evaluator used to compute the quality of the original and the candidate trees.</param>
    /// <param name="maximization">True if larger quality values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Bounds both the relative size of a single replaced branch and the total relative size reduction.</param>
    /// <param name="qualityGainWeight">Factor applied to the quality deterioration when computing the gain.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Rows in [samplesStart, samplesEnd) that lie outside the test partition.
      // Materialized once because the same row set is enumerated for every evaluation below.
      List<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < problemData.TestSamplesStart.Value || problemData.TestSamplesEnd.Value <= i)
        .ToList();
      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);

      // no candidate may shrink the tree below this size
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // tree for branch evaluation: a clone whose first sub-tree of the root has been emptied,
      // so that a single branch can be attached and interpreted on its own
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // all branches that are small enough to be replaced without violating the size limits
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);

            // mean output of the selected branch, computed through the template tree
            templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
            double branchMean = branchValues.Average();
            templateTree.Root.SubTrees[0].RemoveSubTree(0);

            // replace the branch by a constant with that mean
            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            // evaluated on the same rows as originalQuality so that the two values are comparable
            // and rows of the test partition do not influence the pruning decision
            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
            double prunedSize = clonedTree.Size;
            // deterioration in quality:
            // exp: MSE : newMse < origMse (improvement)   => prefer the larger improvement
            //      MSE : newMse > origMse (deterioration) => prefer the smaller deterioration
            //      MSE : minimize: newMse / origMse
            //      R²  : newR² > origR²   (improvement)   => prefer the larger improvement
            //      R²  : newR² < origR²   (deterioration) => prefer the smaller deterioration
            //      R²  : minimize: origR² / newR²
            double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
            // size of the pruned tree is never larger than the size of the original tree
            // same change in quality => prefer pruning operation that removes a larger tree
            double gain = (qualityDeterioration * qualityGainWeight) /
                           (originalSize / prunedSize);
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node whose value is <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.