Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 6739

Last change on this file since 6739 was 6739, checked in by mkommend, 13 years ago

#1579: Re-added the SymbolicRegressionTournamentPruning analyzer to Problems.DataAnalysis.Regression-3.3 for backwards-compatibility reasons.

File size: 18.5 KB
RevLine 
[3874]1#region License Information
2/* HeuristicLab
[4028]3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[3874]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
[6739]24using HeuristicLab.Common;
[3874]25using HeuristicLab.Core;
26using HeuristicLab.Data;
[4068]27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[4028]28using HeuristicLab.Operators;
[4068]29using HeuristicLab.Optimization;
[4028]30using HeuristicLab.Parameters;
[4468]31using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
[4028]32using HeuristicLab.Problems.DataAnalysis.Symbolic;
33using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
[3874]34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer that, on selected generations, prunes the symbolic expression trees found in a
  /// configurable index window of the scope tree. Each tree is pruned by replacing randomly chosen
  /// branches with a constant equal to the branch's mean output. Of <c>TournamentSize</c> candidates,
  /// the one with the best quality-vs-size trade-off is kept.
  /// NOTE(review): the "percentile" window is only meaningful if the trees arrive sorted by quality —
  /// confirm against the algorithm that hosts this analyzer.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string EvaluatorParameterName = "Evaluator";
    private const string MaximizationParameterName = "Maximization";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public ILookupParameter<ISymbolicRegressionEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicRegressionEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public ISymbolicRegressionEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    /// <summary>Deserialization constructor; parameters are restored by the persistence layer.</summary>
    protected SymbolicRegressionTournamentPruning(bool deserializing) : base(deserializing) { }

    /// <summary>Creates the analyzer and registers all of its parameters with their default values.</summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator that should be used to determine which branches are not relevant."));
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>Cloning constructor; the base class copies the parameter collection.</summary>
    protected SymbolicRegressionTournamentPruning(SymbolicRegressionTournamentPruning original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicRegressionTournamentPruning(this, cloner);
    }

    /// <summary>
    /// Adds parameters that are missing in files persisted by older versions of this analyzer.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      #region compatibility remove before releasing 3.3.1
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicRegressionEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      }
      if (!Parameters.ContainsKey(MaximizationParameterName)) {
        Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      }
      #endregion
    }

    /// <summary>
    /// Prunes the trees in the window [PopulationPercentileStart, PopulationPercentileEnd) of the
    /// scope tree. This happens in generation FirstPruningGeneration and every PruningFrequency
    /// generations after it.
    /// </summary>
    public override IOperation Apply() {
      bool pruningCondition =
        (Generation.Value >= FirstPruningGeneration.Value) &&
        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        ItemArray<SymbolicExpressionTree> allTrees = SymbolicExpressionTree;
        int n = allTrees.Length;
        // Derive both window bounds from n separately. Computing the count as
        // (int)(n * (end - start)) can lose a tree to floating-point rounding:
        // n = 10, start = 0.1, end = 0.3 gives 1 instead of 2, because 0.3 - 0.1 < 0.2.
        int startIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);
        // Take() with a non-positive count yields nothing, so end <= start prunes no tree.
        var trees = allTrees
          .Skip(startIndex)
          .Take(endIndex - startIndex);
        foreach (var tree in trees) {
          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
            SymbolicExpressionTreeInterpreter, Evaluator, Maximization.Value,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place. In each of <paramref name="iterations"/> iterations,
    /// <paramref name="tournamentSize"/> candidates are created. Each candidate replaces one random
    /// branch with a constant holding that branch's mean output. The candidate with the smallest
    /// gain value, (quality ratio * qualityGainWeight) * prunedSize / originalSize, is kept.
    /// </summary>
    /// <param name="random">Random source used to pick the branch to prune.</param>
    /// <param name="tree">The tree to prune; its <c>Root</c> is replaced by the pruned root.</param>
    /// <param name="iterations">Number of successive pruning rounds.</param>
    /// <param name="tournamentSize">Number of pruning candidates compared per round.</param>
    /// <param name="problemData">Supplies the dataset, the target variable and the test partition.</param>
    /// <param name="samplesStart">First row (inclusive) used for evaluation.</param>
    /// <param name="samplesEnd">End row (exclusive, as passed to <c>Enumerable.Range</c>) used for evaluation.</param>
    /// <param name="interpreter">Interpreter used to compute branch outputs.</param>
    /// <param name="evaluator">Evaluator used to compute tree quality.</param>
    /// <param name="maximization">True if larger evaluator values are better.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a pruned branch, and maximal total size reduction, relative to the tree size.</param>
    /// <param name="qualityGainWeight">Weight of the quality ratio in the gain value.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter, ISymbolicRegressionEvaluator evaluator, bool maximization,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Rows of [samplesStart, samplesEnd) outside the test partition. The list is materialized once
      // because it is enumerated for every evaluation below. The original and the pruned quality must
      // be computed on these same rows. Otherwise the quality ratio compares different data sets and
      // test samples influence which branch gets pruned.
      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart)
        .Where(i => i < problemData.TestSamplesStart.Value || problemData.TestSamplesEnd.Value <= i)
        .ToList();
      var dataset = problemData.Dataset;
      var targetVariable = problemData.TargetVariable.Value;

      int originalSize = tree.Size;
      double originalQuality = evaluator.Evaluate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);

      // A pruned tree must keep at least this many nodes.
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Copy of the tree whose root's first child has all children removed. A candidate branch is
      // attached there temporarily so that the branch can be evaluated on its own.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      var templateBranchParent = templateTree.Root.SubTrees[0];
      while (templateBranchParent.SubTrees.Count > 0) templateBranchParent.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Candidate branches: small enough to prune, and removing them keeps the tree above minPrunedSize.
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                            .ToList();
          if (prunePoints.Count == 0) continue;

          var selectedPrunePoint = prunePoints.SelectRandom(random);

          // Mean output of the selected branch on the evaluation rows.
          templateBranchParent.AddSubTree(selectedPrunePoint.Branch);
          double branchMean = interpreter.GetSymbolicExpressionTreeValues(templateTree, dataset, rows).Average();
          templateBranchParent.RemoveSubTree(0);

          // Replace the branch by a constant with the branch's mean output.
          selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
          var constNode = CreateConstant(branchMean);
          selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

          double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
            lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);
          double prunedSize = clonedTree.Size;
          // Quality deterioration (smaller is better):
          //   minimization (e.g. MSE): newMse / origMse; an improvement gives < 1, a deterioration > 1
          //   maximization (e.g. R²):  origR² / newR²;   an improvement gives < 1, a deterioration > 1
          double qualityDeterioration = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
          // The pruned tree is always smaller than the original tree. For the same change in quality,
          // prefer the pruning operation that removes more nodes (smaller prunedSize -> smaller gain).
          double gain = (qualityDeterioration * qualityGainWeight) /
                         (originalSize / prunedSize);
          if (gain < bestGain) {
            bestGain = gain;
            iterationBestTree = clonedTree;
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node that holds <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.