Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 7214

Last change on this file since 7214 was 7214, checked in by ascheibe, 12 years ago

#1706 adapted outdated plugins to changes in IAnalyzer

File size: 18.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Operators;
29using HeuristicLab.Optimization;
30using HeuristicLab.Parameters;
31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that, on selected generations, prunes a positional slice of the symbolic regression
  /// trees in scope. Each pruning iteration runs a small tournament: several candidate prunings are
  /// created (a random branch is replaced by a constant equal to the branch's mean output over the
  /// training rows) and the candidate with the lowest combined quality/size score is kept.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    public virtual bool EnabledByDefault {
      get { return true; }
    }

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    protected SymbolicRegressionTournamentPruning(SymbolicRegressionTournamentPruning original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionTournamentPruning(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the configured population slice if the current generation is a pruning
    /// generation, i.e. <c>Generation &gt;= FirstPruningGeneration</c> and
    /// <c>(Generation - FirstPruningGeneration)</c> is a multiple of <c>PruningFrequency</c>.
    /// </summary>
    /// <remarks>
    /// The slice is positional: trees are taken in scope order from index
    /// <c>(int)(n * PopulationPercentileStart)</c> up to (excluding) <c>(int)(n * PopulationPercentileEnd)</c>.
    /// Whether scope order reflects quality is not established by this operator.
    /// </remarks>
    /// <exception cref="System.InvalidOperationException">
    /// Thrown when <c>PruningFrequency</c> is not positive and the first pruning generation has been reached.
    /// </exception>
    public override IOperation Apply() {
      int generation = Generation.Value;
      int firstPruningGeneration = FirstPruningGeneration.Value;
      if (generation >= firstPruningGeneration) {
        int pruningFrequency = PruningFrequency.Value;
        if (pruningFrequency <= 0)
          throw new System.InvalidOperationException("The pruning frequency must be a positive integer.");

        if ((generation - firstPruningGeneration) % pruningFrequency == 0) {
          ItemArray<SymbolicExpressionTree> population = SymbolicExpressionTree;
          int n = population.Length;
          // Compute both slice bounds directly; computing the length as (int)(n * (end - start))
          // can lose a tree to floating-point error (e.g. 0.3 - 0.1 < 0.2).
          int firstIndex = (int)(n * PopulationPercentileStart.Value);
          int endIndex = (int)(n * PopulationPercentileEnd.Value);
          foreach (var tree in population.Skip(firstIndex).Take(endIndex - firstIndex)) {
            Prune(Random, tree, Iterations.Value, TournamentSize.Value,
              DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
              SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
              LowerEstimationLimit.Value, UpperEstimationLimit.Value,
              MaxPruningRatio.Value, QualityGainWeight.Value);
          }
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place by repeatedly replacing a branch with a constant.
    /// </summary>
    /// <param name="random">Random generator used to pick candidate branches.</param>
    /// <param name="tree">The tree to prune; its <c>Root</c> is replaced by the best pruned root.</param>
    /// <param name="iterations">Number of successive pruning iterations.</param>
    /// <param name="tournamentSize">Number of random candidate prunings compared per iteration.</param>
    /// <param name="problemData">Provides the dataset, target variable and test partition.</param>
    /// <param name="samplesStart">First row (inclusive) used for evaluation.</param>
    /// <param name="samplesEnd">Last row (exclusive) used for evaluation; rows inside the test partition are skipped.</param>
    /// <param name="interpreter">Interpreter used to compute branch outputs.</param>
    /// <param name="evaluator">Evaluator used to compute tree quality.</param>
    /// <param name="maximization">True if larger evaluator values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a removed branch relative to the current tree; also bounds
    /// the total size reduction relative to the original tree.</param>
    /// <param name="qualityGainWeight">Factor applied to the quality ratio in the candidate score.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Rows of [samplesStart, samplesEnd) that lie outside the test partition. Materialized once
      // because the sequence is enumerated for every evaluation below.
      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < problemData.TestSamplesStart.Value || problemData.TestSamplesEnd.Value <= i)
        .ToList();
      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);

      // A pruned tree must keep at least this many nodes.
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Copy of the tree whose start node (first child of the root) is emptied; each candidate branch
      // is temporarily attached there so that its output can be computed on its own.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      var templateStart = templateTree.Root.SubTrees[0];
      while (templateStart.SubTrees.Count > 0) templateStart.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Candidate branches: small enough to remove and leaving a large enough tree. Direct
          // children of the root (such as the start node) are never replaced, because that would
          // remove the node the rest of this method relies on (Root.SubTrees[0]).
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             where node != clonedTree.Root
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);

            // Mean output of the selected branch on the evaluation rows; the template tree is
            // restored even if the interpreter throws.
            double branchMean;
            templateStart.AddSubTree(selectedPrunePoint.Branch);
            try {
              branchMean = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows).Average();
            } finally {
              templateStart.RemoveSubTree(0);
            }

            // Replace the branch by a constant holding its mean output.
            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            // Evaluate on the same rows as the original tree so the quality ratio is meaningful
            // (previously the full range including the test partition was used here).
            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
            double prunedSize = clonedTree.Size;
            // Quality deterioration as a ratio; values > 1 mean the pruned tree is worse:
            //      MSE (minimize): newMse / origMse
            //      R²  (maximize): origR² / newR²
            double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
            // The pruned tree is smaller than the original, so (originalSize / prunedSize) > 1 and a
            // larger removal lowers the score for the same quality change. Lower score wins.
            // NOTE(review): qualityGainWeight multiplies every candidate's score by the same factor,
            // so a positive weight does not change which candidate wins — confirm intended weighting.
            double gain = (qualityDeterioration * qualityGainWeight) /
                           (originalSize / prunedSize);
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node holding <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.