Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4468

Last change on this file since 4468 was 4468, checked in by mkommend, 14 years ago

Preparation for cross validation - removed the test samples from the training samples and added ValidationPercentage parameter (ticket #1199).

File size: 18.2 KB
RevLine 
[3874]1#region License Information
2/* HeuristicLab
[4028]3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[3874]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
[4068]26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[4028]27using HeuristicLab.Operators;
[4068]28using HeuristicLab.Optimization;
[4028]29using HeuristicLab.Parameters;
[4468]30using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[4028]31using HeuristicLab.Problems.DataAnalysis.Symbolic;
32using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
[3874]33
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that prunes symbolic regression trees in a configurable slice of the population.
  /// For each tree, a tournament of randomly chosen branches is replaced by constants (the mean
  /// output of the replaced branch on the training rows); the candidate with the lowest gain score
  /// (quality deterioration ratio, weighted and scaled by the relative size reduction) is kept.
  /// Pruning runs only from <c>FirstPruningGeneration</c> on, every <c>PruningFrequency</c> generations.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>
    /// Adds parameters that are missing in operators persisted before they were introduced.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the configured population slice when the current generation is a pruning
    /// generation, then continues with the successor operator.
    /// </summary>
    public override IOperation Apply() {
      bool pruningCondition =
        (Generation.Value >= FirstPruningGeneration.Value) &&
        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        ItemArray<SymbolicExpressionTree> population = SymbolicExpressionTree;
        int n = population.Length;
        // Select positions [n * start, n * end) of the trees in scope. Both boundaries are truncated
        // separately: truncating n * (end - start) instead loses a tree whenever the difference is not
        // exactly representable (e.g. 0.3 - 0.1), and adjacent slices would not partition the population.
        int firstIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);
        var trees = population
          .Skip(firstIndex)
          .Take(endIndex - firstIndex);
        foreach (var tree in trees) {
          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
            SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place by repeatedly replacing a branch with a constant.
    /// In each of <paramref name="iterations"/> iterations, <paramref name="tournamentSize"/> random
    /// candidate replacements are evaluated and the one with the lowest gain score is kept (the tree is
    /// left unchanged in an iteration if no candidate has a finite score below +Infinity).
    /// </summary>
    /// <param name="random">Random number generator used to select prune points.</param>
    /// <param name="tree">The tree to prune; its <c>Root</c> is replaced by the pruned root.</param>
    /// <param name="iterations">Number of pruning iterations.</param>
    /// <param name="tournamentSize">Number of candidate replacements compared per iteration.</param>
    /// <param name="problemData">Provides the dataset, target variable and test partition.</param>
    /// <param name="samplesStart">First row of the evaluation range (inclusive).</param>
    /// <param name="samplesEnd">End of the evaluation range (exclusive); test rows inside the range are skipped.</param>
    /// <param name="interpreter">Interpreter used to compute the output of a candidate branch.</param>
    /// <param name="evaluator">Evaluator used to compute the quality of the original and pruned trees.</param>
    /// <param name="maximization">True if larger evaluator values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a removed branch relative to the current tree size;
    /// also bounds the total size reduction relative to the original tree.</param>
    /// <param name="qualityGainWeight">Factor applied to the quality deterioration in the gain score.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Training rows: [samplesStart, samplesEnd) without the test partition. Materialized once because
      // every evaluation and interpretation below enumerates it.
      int testSamplesStart = problemData.TestSamplesStart.Value;
      int testSamplesEnd = problemData.TestSamplesEnd.Value;
      List<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < testSamplesStart || testSamplesEnd <= i)
        .ToList();
      // No training rows: qualities and branch means are undefined (Average() would throw), so leave the tree as is.
      if (rows.Count == 0) return;

      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);

      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Template tree for evaluating a single branch: a clone whose first root child has no children.
      // A candidate branch is temporarily attached there to compute its output.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Candidate prune points: (parent, child) edges whose child branch is small enough and whose
          // removal keeps the tree at or above the minimal size. Direct children of the root are
          // structural (the first one hosts the evaluated branch, see templateTree) and are never replaced.
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             where node != clonedTree.Root
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);
            // Mean output of the selected branch on the training rows; Average() must run while the
            // branch is attached to the template tree because the interpreter output is lazily enumerated.
            templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
            double branchMean = branchValues.Average();
            templateTree.Root.SubTrees[0].RemoveSubTree(0);

            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            // Evaluated on the same training rows as originalQuality so both qualities are comparable
            // and the test partition never influences the pruning decision.
            double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
            double prunedSize = clonedTree.Size;
            // Quality deterioration ratio (values below 1 mean the pruned tree is better):
            //   minimization (e.g. MSE): newMse / origMse -> prefer the larger improvement / smaller deterioration
            //   maximization (e.g. R²):  origR² / newR²   -> prefer the larger improvement / smaller deterioration
            double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
            // prunedSize <= originalSize: for the same change in quality, the score prefers pruning
            // operations that remove a larger branch.
            // NOTE(review): qualityGainWeight scales every candidate's score by the same factor, so for a
            // positive weight it does not change which candidate is selected — confirm the intended weighting.
            double gain = (qualityDeterioration * qualityGainWeight) /
                           (originalSize / prunedSize);
            // NaN scores (e.g. 0/0 qualities) never compare less and are therefore never selected.
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node holding <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.