Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4502

Last change on this file since 4502 was 4468, checked in by mkommend, 14 years ago

Preparation for cross validation - removed the test samples from the training samples and added ValidationPercentage parameter (ticket #1199).

File size: 18.2 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
33
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that prunes symbolic regression trees by replacing randomly selected branches with
  /// constant nodes holding the branch's mean output on the training rows. In each pruning
  /// iteration a tournament of random pruning candidates is evaluated and the candidate with the
  /// lowest combined quality/size score is kept.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      // Operators persisted before these parameters existed are upgraded on load.
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes every tree whose index lies in
    /// [n * PopulationPercentileStart, n * PopulationPercentileEnd) of the tree array.
    /// Pruning happens only when Generation &gt;= FirstPruningGeneration and
    /// (Generation - FirstPruningGeneration) is a multiple of PruningFrequency.
    /// PruningFrequency must be non-zero: it is used as a modulo divisor.
    /// </summary>
    public override IOperation Apply() {
      int generation = Generation.Value;
      int firstPruningGeneration = FirstPruningGeneration.Value;
      bool pruningCondition =
        (generation >= firstPruningGeneration) &&
        ((generation - firstPruningGeneration) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        var population = SymbolicExpressionTree;
        int n = population.Length;
        // Both boundaries are computed independently.
        // Truncating n * (end - start) can lose a tree to floating-point error,
        // e.g. 10 * (0.3 - 0.1) = 1.999... -> 1 instead of 2.
        int firstIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);
        var treesToPrune = population.Skip(firstIndex).Take(endIndex - firstIndex).ToList();
        if (treesToPrune.Count > 0) {
          // These values do not change between trees; resolve each lookup once.
          // They are read in the same order as the arguments of Prune.
          var random = Random;
          int iterations = Iterations.Value;
          int tournamentSize = TournamentSize.Value;
          var problemData = DataAnalysisProblemData;
          int samplesStart = SamplesStart.Value;
          int samplesEnd = SamplesEnd.Value;
          var interpreter = SymbolicExpressionTreeInterpreter;
          var evaluator = Evaluator;
          bool maximization = Maximization.Value;
          double lowerEstimationLimit = LowerEstimationLimit.Value;
          double upperEstimationLimit = UpperEstimationLimit.Value;
          double maxPruningRatio = MaxPruningRatio.Value;
          double qualityGainWeight = QualityGainWeight.Value;
          foreach (var tree in treesToPrune) {
            Prune(random, tree, iterations, tournamentSize,
              problemData, samplesStart, samplesEnd,
              interpreter, evaluator, maximization,
              lowerEstimationLimit, upperEstimationLimit,
              maxPruningRatio, qualityGainWeight);
          }
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place by replacing branches with constants
    /// (<c>tree.Root</c> is overwritten with the best pruned variant found).
    /// </summary>
    /// <param name="random">Random number generator used to select prune points.</param>
    /// <param name="tree">The tree to prune.</param>
    /// <param name="iterations">Number of pruning iterations; each iteration starts from the best tree of the previous one.</param>
    /// <param name="tournamentSize">Number of random pruning candidates evaluated per iteration.</param>
    /// <param name="problemData">Supplies the dataset, the target variable and the test partition bounds.</param>
    /// <param name="samplesStart">First row (inclusive) of the evaluation range.</param>
    /// <param name="samplesEnd">End row (exclusive) of the evaluation range; rows inside the test partition are skipped.</param>
    /// <param name="interpreter">Interpreter used to compute branch outputs.</param>
    /// <param name="evaluator">Evaluator used to compute the quality of the original and pruned trees.</param>
    /// <param name="maximization">True if higher evaluator values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a pruned branch relative to the current tree size;
    /// pruned trees also keep at least (1 - maxPruningRatio) of the original size.</param>
    /// <param name="qualityGainWeight">Factor applied to the quality term of the candidate score.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Training rows are the evaluation range minus the test partition.
      // They are materialized once because they are enumerated for every evaluation below.
      List<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < problemData.TestSamplesStart.Value || problemData.TestSamplesEnd.Value <= i)
        .ToList();
      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);

      // no pruned tree may shrink below this many nodes
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Harness for evaluating a single branch: a copy of the tree whose first root child has no children.
      // A candidate branch is attached there temporarily to compute its output.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      var templateStart = templateTree.Root.SubTrees[0];
      while (templateStart.SubTrees.Count > 0) templateStart.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // A candidate is any branch that is small enough and whose removal keeps the tree
          // above minPrunedSize.
          // Direct children of the root are excluded: the branch-evaluation harness relies on
          // them as structural nodes, and replacing one with a constant would corrupt the tree.
          // NOTE(review): branches inside function-definition bodies may reference arguments
          // that the harness cannot supply — confirm whether they should be excluded as well.
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             where node != clonedTree.Root
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);

            // mean output of the selected branch on the training rows
            templateStart.AddSubTree(selectedPrunePoint.Branch);
            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
            double branchMean = branchValues.Average();
            templateStart.RemoveSubTree(0);

            // replace the branch by a constant holding its mean output
            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            // Use the same training rows as originalQuality.
            // This keeps the ratio below like-for-like and keeps test samples out of the pruning decision.
            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
            double prunedSize = clonedTree.Size;
            // Quality deterioration ratio, minimized (values < 1 mean the pruned tree is better):
            //   minimization, e.g. MSE: prunedMse / origMse
            //   maximization, e.g. R²:  origR² / prunedR²
            double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
            // prunedSize <= originalSize.
            // Dividing by originalSize / prunedSize prefers, for the same quality change,
            // the candidate that removes more nodes.
            // NOTE(review): qualityGainWeight is the same factor for every candidate, so for a
            // positive weight it never changes which candidate wins; confirm the intended weighting.
            double gain = (qualityDeterioration * qualityGainWeight) /
                           (originalSize / prunedSize);
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node holding <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.