Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs @ 4034

Last change on this file since 4034 was 4034, checked in by mkommend, 14 years ago

implemented first version of partial evaluation of samples (ticket #1082)

File size: 16.0 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using System;
27using HeuristicLab.Operators;
28using HeuristicLab.Parameters;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Problems.DataAnalysis.Symbolic;
31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
33using HeuristicLab.Optimization;
34
namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
  /// <summary>
  /// Analyzer operator that, in selected generations, prunes the symbolic expression trees of a slice of the
  /// population. A pruning step replaces a randomly chosen branch with a constant equal to the branch's mean
  /// output on the evaluation rows. Out of <c>TournamentSize</c> such candidates, the one with the lowest
  /// (weighted MSE ratio) / (size ratio) gain is kept.
  /// </summary>
  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SamplesStartParameterName = "SamplesStart";
    private const string SamplesEndParameterName = "SamplesEnd";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
    private const string TournamentSizeParameterName = "TournamentSize";
    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
    private const string QualityGainWeightParameterName = "QualityGainWeight";
    private const string IterationsParameterName = "Iterations";
    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
    private const string PruningFrequencyParameterName = "PruningFrequency";
    private const string GenerationParameterName = "Generations";
    private const string ResultsParameterName = "Results";

    #region parameter properties
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> SamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
    }
    public IValueLookupParameter<IntValue> TournamentSizeParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
    }
    public IValueLookupParameter<IntValue> IterationsParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
    }
    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
    }
    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
    }
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion
    #region properties
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public IntValue SamplesStart {
      get { return SamplesStartParameter.ActualValue; }
    }
    public IntValue SamplesEnd {
      get { return SamplesEndParameter.ActualValue; }
    }
    public DoubleValue MaxPruningRatio {
      get { return MaxPruningRatioParameter.ActualValue; }
    }
    public IntValue TournamentSize {
      get { return TournamentSizeParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileStart {
      get { return PopulationPercentileStartParameter.ActualValue; }
    }
    public DoubleValue PopulationPercentileEnd {
      get { return PopulationPercentileEndParameter.ActualValue; }
    }
    public DoubleValue QualityGainWeight {
      get { return QualityGainWeightParameter.ActualValue; }
    }
    public IntValue Iterations {
      get { return IterationsParameter.ActualValue; }
    }
    public IntValue PruningFrequency {
      get { return PruningFrequencyParameter.ActualValue; }
    }
    public IntValue FirstPruningGeneration {
      get { return FirstPruningGenerationParameter.ActualValue; }
    }
    public IntValue Generation {
      get { return GenerationParameter.ActualValue; }
    }
    #endregion

    /// <summary>
    /// Registers all operator parameters together with their default values.
    /// </summary>
    public SymbolicRegressionTournamentPruning()
      : base() {
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
    }

    /// <summary>
    /// Prunes the trees in the configured population slice when the current generation is a pruning
    /// generation. A pruning generation is <c>FirstPruningGeneration</c> and every <c>PruningFrequency</c>-th
    /// generation after it. The trees are modified in place.
    /// </summary>
    /// <returns>The operation returned by <see cref="SingleSuccessorOperator"/>'s <c>Apply</c>.</returns>
    public override IOperation Apply() {
      int generation = Generation.Value;
      int firstPruningGeneration = FirstPruningGeneration.Value;
      bool pruningCondition =
        (generation >= firstPruningGeneration) &&
        ((generation - firstPruningGeneration) % PruningFrequency.Value == 0);
      if (pruningCondition) {
        ItemArray<SymbolicExpressionTree> allTrees = SymbolicExpressionTree;
        int n = allTrees.Length;
        // Derive both slice boundaries directly from n. Computing the length as
        // (int)(n * (end - start)) can lose one tree to floating-point error in the subtraction
        // (e.g. n = 10, start = 0.1, end = 0.3 yields 1 instead of 2).
        // NOTE(review): the slice is taken in scope order; it is a quality percentile only if the
        // scopes are sorted by quality upstream — confirm.
        int startIndex = (int)(n * PopulationPercentileStart.Value);
        int endIndex = (int)(n * PopulationPercentileEnd.Value);
        var trees = allTrees
          .Skip(startIndex)
          .Take(endIndex - startIndex);

        // These values do not change while the operator runs, so look them up once instead of per tree.
        IRandom random = Random;
        int iterations = Iterations.Value;
        int tournamentSize = TournamentSize.Value;
        DataAnalysisProblemData problemData = DataAnalysisProblemData;
        int samplesStart = SamplesStart.Value;
        int samplesEnd = SamplesEnd.Value;
        ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;
        double lowerEstimationLimit = LowerEstimationLimit.Value;
        double upperEstimationLimit = UpperEstimationLimit.Value;
        double maxPruningRatio = MaxPruningRatio.Value;
        double qualityGainWeight = QualityGainWeight.Value;

        foreach (var tree in trees) {
          Prune(random, tree, iterations, tournamentSize,
            problemData, samplesStart, samplesEnd,
            interpreter,
            lowerEstimationLimit, upperEstimationLimit,
            maxPruningRatio, qualityGainWeight);
        }
      }
      return base.Apply();
    }

    /// <summary>
    /// Prunes <paramref name="tree"/> in place. Each iteration draws <paramref name="tournamentSize"/> candidates.
    /// A candidate is produced by replacing one randomly chosen branch of the current tree with a constant
    /// equal to that branch's mean output. The iteration keeps the candidate with the smallest gain,
    /// <c>(qualityGainWeight * prunedMse / originalMse) / (originalSize / prunedSize)</c>.
    /// </summary>
    /// <param name="random">Random generator used to pick prune points.</param>
    /// <param name="tree">The tree to prune; its <c>Root</c> is replaced by the pruned result.</param>
    /// <param name="iterations">Number of successive pruning iterations.</param>
    /// <param name="tournamentSize">Number of random prune candidates compared per iteration.</param>
    /// <param name="problemData">Provides the dataset and the target variable.</param>
    /// <param name="samplesStart">First row used for evaluation (inclusive).</param>
    /// <param name="samplesEnd">End of the evaluated rows (exclusive, see <c>Enumerable.Range</c> below).</param>
    /// <param name="interpreter">Interpreter used to evaluate branches and trees.</param>
    /// <param name="lowerEstimationLimit">Lower estimation limit passed to the MSE evaluator.</param>
    /// <param name="upperEstimationLimit">Upper estimation limit passed to the MSE evaluator.</param>
    /// <param name="maxPruningRatio">Maximal size of a pruned branch relative to the current tree size; also bounds
    /// the total size reduction relative to the original tree.</param>
    /// <param name="qualityGainWeight">Factor applied to the MSE ratio in the gain formula.</param>
    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
      ISymbolicExpressionTreeInterpreter interpreter,
      double lowerEstimationLimit, double upperEstimationLimit,
      double maxPruningRatio, double qualityGainWeight) {
      // Every evaluation below (original tree, single branches, pruned candidates) uses these rows.
      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
      int originalSize = tree.Size;
      double originalMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, tree,
        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);

      // No pruned tree may shrink below this size, whatever the number of iterations.
      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));

      // Harness for evaluating a single branch: a clone of the tree whose first root child has no children.
      // A candidate branch is attached to that node, evaluated, and detached again.
      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
      var templateHostNode = templateTree.Root.SubTrees[0];
      while (templateHostNode.SubTrees.Count > 0) templateHostNode.RemoveSubTree(0);

      SymbolicExpressionTree prunedTree = tree;
      for (int iteration = 0; iteration < iterations; iteration++) {
        SymbolicExpressionTree iterationBestTree = prunedTree;
        double bestGain = double.PositiveInfinity;
        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);

        for (int i = 0; i < tournamentSize; i++) {
          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
          int clonedTreeSize = clonedTree.Size;
          // Direct children of the root are structural nodes. The first one hosts the evaluated expression;
          // any others presumably hold function definitions (confirm). Replacing one with a constant would
          // break the tree, and attaching one to the template host node is not a valid branch evaluation,
          // so they are never prune points.
          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
                             where node != clonedTree.Root
                             from subTree in node.SubTrees
                             let subTreeSize = subTree.GetSize()
                             where subTreeSize <= maxPrunedBranchSize
                             where clonedTreeSize - subTreeSize >= minPrunedSize
                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
                 .ToList();
          if (prunePoints.Count > 0) {
            var selectedPrunePoint = prunePoints.SelectRandom(random);

            // Mean output of the selected branch on the evaluation rows.
            templateHostNode.AddSubTree(selectedPrunePoint.Branch);
            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
            double branchMean = branchValues.Average();
            templateHostNode.RemoveSubTree(0);

            // Replace the branch with a constant holding that mean.
            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
            var constNode = CreateConstant(branchMean);
            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

            double prunedMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, clonedTree,
              lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
            double prunedSize = clonedTree.Size;
            // MSE of the pruned tree is larger than the original tree in most cases
            // size of the pruned tree is always smaller than the size of the original tree
            // same change in quality => prefer pruning operation that removes a larger tree
            // NOTE(review): for any positive qualityGainWeight this factor scales every candidate's gain equally,
            // so it does not change which candidate is selected; confirm whether another weighting was intended.
            // If originalMse is 0 the gain is NaN or +Infinity and no candidate is ever accepted.
            double gain = ((prunedMse / originalMse) * qualityGainWeight) /
                           (originalSize / prunedSize);
            if (gain < bestGain) {
              bestGain = gain;
              iterationBestTree = clonedTree;
            }
          }
        }
        prunedTree = iterationBestTree;
      }
      tree.Root = prunedTree.Root;
    }

    /// <summary>
    /// Creates a constant tree node holding <paramref name="constantValue"/>.
    /// </summary>
    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
      node.Value = constantValue;
      return node;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.