source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4068

Last change on this file since 4068 was 4068, checked in by swagner, 9 years ago

Sorted usings and removed unused usings in entire solution (#1094)

File size: 15.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
27using HeuristicLab.Operators;
28using HeuristicLab.Optimization;
29using HeuristicLab.Parameters;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Operator that simplifies symbolic regression trees by tournament-based pruning.
  /// A randomly selected branch is replaced by a constant equal to the branch's mean output
  /// on the evaluation rows. Among several such candidates, the one with the smallest
  /// size-weighted relative quality loss is kept.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Registers all parameters of the operator. Where a default is given, the value-lookup
    /// parameters get it here (e.g. MaxPruningRatio 0.5, TournamentSize 10, percentile window 0.25-0.75).
    /// </summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>
    /// Prunes the trees that lie in the configured population percentile window, but only in a
    /// pruning generation. A pruning generation is at or after FirstPruningGeneration and a whole
    /// multiple of PruningFrequency generations after it.
    /// </summary>
    public override IOperation Apply() {
      int generation = Generation.Value;
      int firstPruningGeneration = FirstPruningGeneration.Value;
      // Short-circuit: PruningFrequency is only read once the first pruning generation is reached.
      bool pruningCondition =
        (generation >= firstPruningGeneration) &&
        ((generation - firstPruningGeneration) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        var allTrees = SymbolicExpressionTree;
        int n = allTrees.Length;
        // Both window bounds are derived independently from the population size. Computing the window
        // length as (int)(n * (end - start)) truncates after a floating-point subtraction and can drop
        // trees: n = 10, start = 0.1, end = 0.3 yields (int)(1.9999999999999998) = 1 instead of 2.
        int startIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);
        // NOTE(review): the window is positional; presumably the trees are ordered by quality
        // upstream — confirm. Take() yields nothing when endIndex <= startIndex.
        var trees = allTrees.Skip(startIndex).Take(endIndex - startIndex);
        foreach (var tree in trees) {
          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
            SymbolicExpressionTreeInterpreter,
            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
            MaxPruningRatio.Value, QualityGainWeight.Value);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place.
    /// Each of the <paramref name="iterations"/> rounds builds up to <paramref name="tournamentSize"/>
    /// candidates. A candidate replaces one randomly selected branch of the current tree with a constant
    /// equal to that branch's mean output on the evaluation rows. The candidate with the smallest finite
    /// gain becomes the current tree for the next round.
    /// </summary>
    /// <param name="random">Random number generator used to pick prune points.</param>
    /// <param name="tree">Tree to prune; its root is replaced by the root of the pruned tree.</param>
    /// <param name="iterations">Number of pruning rounds.</param>
    /// <param name="tournamentSize">Number of pruning candidates generated per round.</param>
    /// <param name="problemData">Provides the dataset and target variable for evaluation.</param>
    /// <param name="samplesStart">First evaluation row (inclusive).</param>
    /// <param name="samplesEnd">End of the evaluation rows (exclusive).</param>
    /// <param name="interpreter">Interpreter used to evaluate trees and branches.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the MSE evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the MSE evaluator.</param>
    /// <param name="maxPruningRatio">
    /// Maximal removed branch size, relative to the current tree. Also bounds the pruned tree to at
    /// least (1 - maxPruningRatio) of the original size.
    /// </param>
    /// <param name="qualityGainWeight">Factor applied to the MSE ratio when computing the gain.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Evaluation rows: samplesStart inclusive, samplesEnd exclusive. Reused for every evaluation below.
      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
      var dataset = problemData.Dataset;
      var targetVariable = problemData.TargetVariable.Value;

      int originalSize = tree.Size;
      double originalMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);

      // No candidate may shrink the tree below this size (relative to the unpruned tree).
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Copy of the tree whose first root branch is emptied. A single candidate branch is attached to it
      // temporarily so the branch can be evaluated on its own; the other root branches are kept as-is.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      var templateResultBranch = templateTree.Root.SubTrees[0];
      while (templateResultBranch.SubTrees.Count > 0) templateResultBranch.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Candidates are restricted to branches below Root.SubTrees[0]. Iterating the whole tree would
          // also offer Root.SubTrees[0] itself, which this method treats as the container of the result
          // expression (see templateTree), plus the other root branches and their descendants. Replacing
          // one of those with a constant corrupts the tree structure.
          var prunePoints = (from node in clonedTree.Root.SubTrees[0].IterateNodesPostfix()
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                            .ToList();
          if (prunePoints.Count == 0) continue;

          var selectedPrunePoint = prunePoints.SelectRandom(random);

          // Mean output of the selected branch evaluated on its own. Average() runs before the branch is
          // detached again, because the interpreter's result may be enumerated lazily.
          templateResultBranch.AddSubTree(selectedPrunePoint.Branch);
          IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, dataset, rows);
          double branchMean = branchValues.Average();
          templateResultBranch.RemoveSubTree(0);

          // Replace the selected branch with a constant holding its mean output.
          selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
          selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, CreateConstant(branchMean));

          double prunedMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, clonedTree,
            lowerEstimationLimit, upperEstimationLimit, dataset, targetVariable, rows);
          double prunedSize = clonedTree.Size;

          // Relative quality loss. A plain division would give 0 / 0 = NaN for a loss-free candidate of a
          // perfect model, and a NaN gain can never be selected.
          double qualityRatio = originalMse == 0.0
            ? (prunedMse == 0.0 ? 1.0 : double.PositiveInfinity)
            : prunedMse / originalMse;
          // Lower is better. For the same quality loss, a candidate that removed a larger branch
          // (smaller prunedSize) gets a smaller gain.
          // NOTE(review): qualityGainWeight scales every candidate's gain by the same factor, so for a
          // positive weight it cannot change which candidate wins — confirm the intended formula.
          double gain = (qualityRatio * qualityGainWeight) / (originalSize / prunedSize);
          if (gain < bestGain) {
            bestGain = gain;
            iterationBestTree = clonedTree;
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>Creates a constant tree node whose value is <paramref name="constantValue"/>.</summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.