Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3.2/sources/HeuristicLab.GP.StructureIdentification/3.3/TournamentPruning.cs @ 4682

Last change on this file since 4682 was 3874, checked in by gkronber, 15 years ago

Added first implementation of a simple randomized greedy pruning operator. #125 (Pruning operator)

File size: 6.9 KB
RevLine 
[3874]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Core;
25using HeuristicLab.Data;
26using HeuristicLab.GP.Interfaces;
27using System;
28using HeuristicLab.DataAnalysis;
29using HeuristicLab.Modeling;
30
31namespace HeuristicLab.GP.StructureIdentification {
/// <summary>
/// Operator that replaces branches of function trees by constants ("pruning").
/// For every tree in a positional slice of the sub-scopes it runs a small tournament of
/// randomly chosen pruning candidates and keeps the candidate with the smallest gain value
/// (see <see cref="Prune"/>).
/// </summary>
public class TournamentPruning : OperatorBase {
  /// <summary>
  /// Registers the variables this operator reads. All of them are declared with
  /// <c>VariableKind.In</c>.
  /// </summary>
  public TournamentPruning()
    : base() {
    AddVariableInfo(new VariableInfo("Random", "", typeof(IRandom), VariableKind.In));
    AddVariableInfo(new VariableInfo("FunctionTree", "The tree to analyse", typeof(IGeneticProgrammingModel), VariableKind.In));
    AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In));
    AddVariableInfo(new VariableInfo("TargetVariable", "", typeof(StringData), VariableKind.In));
    AddVariableInfo(new VariableInfo("TrainingSamplesStart", "Samples start", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "Samples end", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo("TreeEvaluator", "", typeof(ITreeEvaluator), VariableKind.In));
    AddVariableInfo(new VariableInfo("MaxPruningRatio", "Maximale relative size of the pruned branch", typeof(DoubleData), VariableKind.In));
    AddVariableInfo(new VariableInfo("TournamentSize", "Number of branches to compare for pruning", typeof(IntData), VariableKind.In));
    AddVariableInfo(new VariableInfo("PopulationPercentileStart", "", typeof(DoubleData), VariableKind.In));
    AddVariableInfo(new VariableInfo("PopulationPercentileEnd", "", typeof(DoubleData), VariableKind.In));
    AddVariableInfo(new VariableInfo("QualityGainWeight", "", typeof(DoubleData), VariableKind.In));
  }

  /// <summary>
  /// Reads the operator parameters from <paramref name="scope"/> and prunes the
  /// "FunctionTree" model of every sub-scope that falls into the configured slice.
  /// </summary>
  /// <param name="scope">Scope whose sub-scopes hold the models to prune.</param>
  /// <returns>Always <c>null</c>.</returns>
  public override IOperation Apply(IScope scope) {
    IRandom random = scope.GetVariableValue<IRandom>("Random", true);
    double percentileStart = scope.GetVariableValue<DoubleData>("PopulationPercentileStart", true).Data;
    double percentileEnd = scope.GetVariableValue<DoubleData>("PopulationPercentileEnd", true).Data;
    int tournamentSize = scope.GetVariableValue<IntData>("TournamentSize", true).Data;
    Dataset dataset = scope.GetVariableValue<Dataset>("Dataset", true);
    string targetVariable = scope.GetVariableValue<StringData>("TargetVariable", true).Data;
    int samplesStart = scope.GetVariableValue<IntData>("TrainingSamplesStart", true).Data;
    int samplesEnd = scope.GetVariableValue<IntData>("TrainingSamplesEnd", true).Data;
    ITreeEvaluator evaluator = scope.GetVariableValue<ITreeEvaluator>("TreeEvaluator", true);
    double maxPruningRatio = scope.GetVariableValue<DoubleData>("MaxPruningRatio", true).Data;
    double qualityGainWeight = scope.GetVariableValue<DoubleData>("QualityGainWeight", true).Data;
    int n = scope.SubScopes.Count;

    // The slice is purely positional: skip the first (int)(n * start) sub-scopes and take
    // (int)(n * (end - start)) of the remaining ones (both products are truncated).
    // NOTE(review): this only matches a quality percentile if the sub-scopes are already
    // sorted by quality - presumably done by an earlier operator; confirm against the caller.
    var trees = (from subScope in scope.SubScopes
                 select subScope.GetVariableValue<IGeneticProgrammingModel>("FunctionTree", false))
                .Skip((int)(n * percentileStart))
                .Take((int)(n * (percentileEnd - percentileStart)));
    foreach (var tree in trees) {
      tree.FunctionTree = Prune(random, tree.FunctionTree, tournamentSize, dataset, targetVariable, samplesStart, samplesEnd, evaluator, maxPruningRatio, qualityGainWeight);
    }
    return null;
  }

  /// <summary>
  /// Runs a tournament of <paramref name="tournamentSize"/> random pruning candidates and
  /// returns the best one.
  /// Each candidate is a clone of <paramref name="tree"/> in which one randomly selected
  /// branch is replaced by a constant holding the mean output of that branch over the
  /// evaluated rows. Candidates are ranked by
  /// <c>gain = ((prunedMse / originalMse) * qualityGainWeight) / (originalSize / prunedSize)</c>
  /// and the candidate with the smallest gain wins.
  /// </summary>
  /// <param name="random">Random number source used to pick the branch to prune.</param>
  /// <param name="tree">Tree to prune; all modifications are applied to clones of it.</param>
  /// <param name="tournamentSize">Number of pruning candidates to generate and compare.</param>
  /// <param name="dataset">Dataset the trees are evaluated on.</param>
  /// <param name="targetVariable">Name of the variable holding the target values.</param>
  /// <param name="samplesStart">First row that is evaluated.</param>
  /// <param name="samplesEnd">End of the evaluated row range; rows are generated with
  /// <c>Enumerable.Range(samplesStart, samplesEnd - samplesStart)</c>, so this row itself
  /// is not evaluated.</param>
  /// <param name="evaluator">Evaluator used to compute tree and branch outputs.</param>
  /// <param name="maxPruningRatio">Upper bound for the size of a prunable branch, relative
  /// to the size of <paramref name="tree"/> (the product is truncated to an integer).</param>
  /// <param name="qualityGainWeight">Factor applied to the MSE ratio in the gain formula.</param>
  /// <returns>
  /// The winning pruned clone, or <paramref name="tree"/> itself when no candidate could be
  /// built (no branch is small enough, or <paramref name="tournamentSize"/> is not positive)
  /// or when no candidate produced a gain below positive infinity.
  /// </returns>
  public static IFunctionTree Prune(IRandom random, IFunctionTree tree, int tournamentSize,
    Dataset dataset, string targetVariable, int samplesStart, int samplesEnd, ITreeEvaluator evaluator,
    double maxPruningRatio, double qualityGainWeight) {
    var evaluatedRows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
    var originalEstimates = evaluator.Evaluate(dataset, tree, evaluatedRows).ToArray();
    var targetValues = dataset.GetVariableValues(targetVariable, samplesStart, samplesEnd);
    int originalSize = tree.GetSize();
    double originalMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, originalEstimates));

    int maxPrunedBranchSize = (int)(originalSize * maxPruningRatio);

    IFunctionTree bestTree = tree;
    double bestGain = double.PositiveInfinity;

    for (int i = 0; i < tournamentSize; i++) {
      // Work on a clone so that the nodes referenced by the prune points can be modified
      // without touching the tree that was passed in.
      var clonedTree = (IFunctionTree)tree.Clone();

      // Candidate branches in prefix order of their parents. The child position comes from
      // the indexed Select overload instead of IndexOf, which avoids a search per child and
      // cannot report the slot of a different, equal-comparing sibling.
      var prunePoints = FunctionTreeIterator.IteratePrefix(clonedTree)
        .SelectMany(node => node.SubTrees.Select((subTree, index) =>
          new { Parent = node, Branch = subTree, SubTreeIndex = index }))
        .Where(point => point.Branch.GetSize() <= maxPrunedBranchSize)
        .ToList();

      // Nothing can be pruned (e.g. the tree is a single node or every branch exceeds the
      // size limit). Every clone has the same structure, so further rounds cannot succeed.
      if (prunePoints.Count == 0) {
        break;
      }

      var selectedPrunePoint = prunePoints[random.Next(prunePoints.Count)];
      var branchValues = evaluator.Evaluate(dataset, selectedPrunePoint.Branch, evaluatedRows);
      var branchMean = branchValues.Average();

      // Replace the selected branch by a constant with the branch's mean output.
      selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
      var constNode = CreateConstant(branchMean);
      selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);

      var prunedEstimates = evaluator.Evaluate(dataset, clonedTree, evaluatedRows).ToArray();
      double prunedMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, prunedEstimates));
      double prunedSize = clonedTree.GetSize();
      // MSE of the pruned tree is larger than the original tree in most cases
      // size of the pruned tree is always smaller than the size of the original tree
      // same change in quality => prefer pruning operation that removes a larger tree
      // NOTE(review): qualityGainWeight scales every candidate's gain by the same factor, so
      // for positive weights it does not change which candidate wins - confirm the intent.
      // NOTE(review): if originalMse is 0 the ratio is NaN or infinite, the comparison below
      // is false and no candidate is ever accepted.
      double gain = ((prunedMse / originalMse) * qualityGainWeight) /
                     (originalSize / prunedSize);
      if (gain < bestGain) {
        bestGain = gain;
        bestTree = clonedTree;
      }
    }

    return bestTree;
  }

  /// <summary>
  /// Creates a constant tree node whose value is <paramref name="constantValue"/>.
  /// </summary>
  /// <param name="constantValue">Value assigned to the new constant node.</param>
  /// <returns>The newly created constant node.</returns>
  private static FunctionTree CreateConstant(double constantValue) {
    var node = (ConstantFunctionTree)(new Constant()).GetTreeNode();
    node.Value = constantValue;
    return node;
  }
}
128}
Note: See TracBrowser for help on using the repository browser.