Changeset 4028 for trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers
- Timestamp:
- 07/13/10 11:48:24 (14 years ago)
- Location:
- trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers
- Files:
-
- 1 deleted
- 1 copied
Legend:
- Unmodified
- Added
- Removed
-
trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs
r3901 r4028 1 1 #region License Information 2 2 /* HeuristicLab 3 * Copyright (C) 2002-20 08Heuristic and Evolutionary Algorithms Laboratory (HEAL)3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL) 4 4 * 5 5 * This file is part of HeuristicLab. … … 24 24 using HeuristicLab.Core; 25 25 using HeuristicLab.Data; 26 using HeuristicLab.GP.Interfaces;27 26 using System; 28 using HeuristicLab.DataAnalysis; 29 using HeuristicLab.Modeling; 30 31 namespace HeuristicLab.GP.StructureIdentification { 32 public class TournamentPruning : OperatorBase { 33 public TournamentPruning() 27 using HeuristicLab.Operators; 28 using HeuristicLab.Parameters; 29 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 30 using HeuristicLab.Problems.DataAnalysis.Symbolic; 31 using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols; 32 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols; 33 using HeuristicLab.Optimization; 34 35 namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers { 36 public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer { 37 private const string RandomParameterName = "Random"; 38 private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree"; 39 private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData"; 40 private const string SamplesStartParameterName = "SamplesStart"; 41 private const string SamplesEndParameterName = "SamplesEnd"; 42 private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter"; 43 private const string UpperEstimationLimitParameterName = "UpperEstimationLimit"; 44 private const string LowerEstimationLimitParameterName = "LowerEstimationLimit"; 45 private const string MaxPruningRatioParameterName = "MaxPruningRatio"; 46 private const string TournamentSizeParameterName = "TournamentSize"; 47 private const string PopulationPercentileStartParameterName = "PopulationPercentileStart"; 48 private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd"; 49 private const string QualityGainWeightParameterName = "QualityGainWeight"; 50 private const string IterationsParameterName = "Iterations"; 51 private const string FirstPruningGenerationParameterName = "FirstPruningGeneration"; 52 private const string PruningFrequencyParameterName = "PruningFrequency"; 53 private const string GenerationParameterName = "Generations"; 54 private const string ResultsParameterName = "Results"; 55 56 #region parameter properties 57 public ILookupParameter<IRandom> RandomParameter { 58 get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; } 59 } 60 public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter { 61 get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; } 62 } 63 public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter { 64 get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; } 65 } 66 public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter { 67 get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; } 68 } 69 public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter { 70 get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; } 71 } 72 public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter { 73 get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; } 74 } 75 public IValueLookupParameter<IntValue> SamplesStartParameter { 76 get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; } 77 } 78 public IValueLookupParameter<IntValue> SamplesEndParameter { 79 get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; } 80 } 81 public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter { 82 get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; } 83 } 84 public IValueLookupParameter<IntValue> TournamentSizeParameter { 85 get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; } 86 } 87 public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter { 88 get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; } 89 } 90 public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter { 91 get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; } 92 } 93 public IValueLookupParameter<DoubleValue> QualityGainWeightParameter { 94 get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; } 95 } 96 public IValueLookupParameter<IntValue> IterationsParameter { 97 get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; } 98 } 99 public IValueLookupParameter<IntValue> FirstPruningGenerationParameter { 100 get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; } 101 } 102 public IValueLookupParameter<IntValue> PruningFrequencyParameter { 103 get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; } 104 } 105 public ILookupParameter<IntValue> GenerationParameter { 106 get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; } 107 } 108 public ILookupParameter<ResultCollection> ResultsParameter { 109 get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; } 110 } 111 #endregion 112 #region properties 113 public IRandom Random { 114 get { return RandomParameter.ActualValue; } 115 } 116 public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree { 117 get { return SymbolicExpressionTreeParameter.ActualValue; } 118 } 119 public DataAnalysisProblemData DataAnalysisProblemData { 120 get { return DataAnalysisProblemDataParameter.ActualValue; } 121 } 122 public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter { 123 get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; } 124 } 125 public DoubleValue UpperEstimationLimit { 126 get { return UpperEstimationLimitParameter.ActualValue; } 127 } 128 public DoubleValue LowerEstimationLimit { 129 get { return LowerEstimationLimitParameter.ActualValue; } 130 } 131 public IntValue SamplesStart { 132 get { return SamplesStartParameter.ActualValue; } 133 } 134 public IntValue SamplesEnd { 135 get { return SamplesEndParameter.ActualValue; } 136 } 137 public DoubleValue MaxPruningRatio { 138 get { return MaxPruningRatioParameter.ActualValue; } 139 } 140 public IntValue TournamentSize { 141 get { return TournamentSizeParameter.ActualValue; } 142 } 143 public DoubleValue PopulationPercentileStart { 144 get { return PopulationPercentileStartParameter.ActualValue; } 145 } 146 public DoubleValue PopulationPercentileEnd { 147 get { return PopulationPercentileEndParameter.ActualValue; } 148 } 149 public DoubleValue QualityGainWeight { 150 get { return QualityGainWeightParameter.ActualValue; } 151 } 152 public IntValue Iterations { 153 get { return IterationsParameter.ActualValue; } 154 } 155 public IntValue PruningFrequency { 156 get { return PruningFrequencyParameter.ActualValue; } 157 } 158 public IntValue FirstPruningGeneration { 159 get { return FirstPruningGenerationParameter.ActualValue; } 160 } 161 public IntValue Generation { 162 get { return GenerationParameter.ActualValue; } 163 } 164 #endregion 165 public SymbolicRegressionTournamentPruning() 34 166 : base() { 35 AddVariableInfo(new VariableInfo("Random", "", typeof(IRandom), VariableKind.In)); 36 AddVariableInfo(new VariableInfo("FunctionTree", "The tree to analyse", typeof(IGeneticProgrammingModel), VariableKind.In)); 37 AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In)); 38 AddVariableInfo(new VariableInfo("TargetVariable", "", typeof(StringData), VariableKind.In)); 39 AddVariableInfo(new VariableInfo("TrainingSamplesStart", "Samples start", typeof(IntData), VariableKind.In)); 40 AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "Samples end", typeof(IntData), VariableKind.In)); 41 AddVariableInfo(new VariableInfo("TreeEvaluator", "", typeof(ITreeEvaluator), VariableKind.In)); 42 AddVariableInfo(new VariableInfo("MaxPruningRatio", "Maximale relative size of the pruned branch", typeof(DoubleData), VariableKind.In)); 43 AddVariableInfo(new VariableInfo("TournamentSize", "Number of branches to compare for pruning", typeof(IntData), VariableKind. 44 In)); 45 AddVariableInfo(new VariableInfo("PopulationPercentileStart", "", typeof(DoubleData), VariableKind.In)); 46 AddVariableInfo(new VariableInfo("PopulationPercentileEnd", "", typeof(DoubleData), VariableKind.In)); 47 AddVariableInfo(new VariableInfo("QualityGainWeight", "", typeof(DoubleData), VariableKind.In)); 48 } 49 50 public override IOperation Apply(IScope scope) { 51 IRandom random = scope.GetVariableValue<IRandom>("Random", true); 52 double percentileStart = scope.GetVariableValue<DoubleData>("PopulationPercentileStart", true).Data; 53 double percentileEnd = scope.GetVariableValue<DoubleData>("PopulationPercentileEnd", true).Data; 54 int tournamentSize = scope.GetVariableValue<IntData>("TournamentSize", true).Data; 55 Dataset dataset = scope.GetVariableValue<Dataset>("Dataset", true); 56 string targetVariable = scope.GetVariableValue<StringData>("TargetVariable", true).Data; 57 int samplesStart = scope.GetVariableValue<IntData>("TrainingSamplesStart", true).Data; 58 int samplesEnd = scope.GetVariableValue<IntData>("TrainingSamplesEnd", true).Data; 59 ITreeEvaluator evaluator = scope.GetVariableValue<ITreeEvaluator>("TreeEvaluator", true); 60 double maxPruningRatio = scope.GetVariableValue<DoubleData>("MaxPruningRatio", true).Data; 61 double qualityGainWeight = scope.GetVariableValue<DoubleData>("QualityGainWeight", true).Data; 62 int n = scope.SubScopes.Count; 63 // for each tree in the given percentile 64 var trees = (from subScope in scope.SubScopes 65 select subScope.GetVariableValue<IGeneticProgrammingModel>("FunctionTree", false)) 66 .Skip((int)(n * percentileStart)) 67 .Take((int)(n * (percentileEnd - percentileStart))); 68 foreach (var tree in trees) { 69 tree.FunctionTree = Prune(random, tree.FunctionTree, tournamentSize, dataset, targetVariable, samplesStart, samplesEnd, evaluator, maxPruningRatio, qualityGainWeight); 70 } 71 return null; 72 } 73 74 public static IFunctionTree Prune(IRandom random, IFunctionTree tree, int tournamentSize, 75 Dataset dataset, string targetVariable, int samplesStart, int samplesEnd, ITreeEvaluator evaluator, 76 double maxPruningRatio, double qualityGainWeight) { 77 var evaluatedRows = Enumerable.Range(samplesStart, samplesEnd - samplesStart); 78 var estimatedValues = evaluator.Evaluate(dataset, tree, evaluatedRows).ToArray(); 79 var targetValues = dataset.GetVariableValues(targetVariable, samplesStart, samplesEnd); 80 int originalSize = tree.GetSize(); 81 double originalMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, estimatedValues)); 82 83 int maxPrunedBranchSize = (int)(tree.GetSize() * maxPruningRatio); 84 85 86 IFunctionTree bestTree = tree; 87 double bestGain = double.PositiveInfinity; 88 89 for (int i = 0; i < tournamentSize; i++) { 90 var clonedTree = (IFunctionTree)tree.Clone(); 91 var prunePoints = (from node in FunctionTreeIterator.IteratePrefix(clonedTree) 92 from subTree in node.SubTrees 93 where subTree.GetSize() <= maxPrunedBranchSize 94 select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) }) 95 .ToList(); 96 97 var selectedPrunePoint = prunePoints[random.Next(prunePoints.Count)]; 98 var branchValues = evaluator.Evaluate(dataset, selectedPrunePoint.Branch, evaluatedRows); 99 var branchMean = branchValues.Average(); 100 101 selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex); 102 var constNode = CreateConstant(branchMean); 103 selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode); 104 105 estimatedValues = evaluator.Evaluate(dataset, clonedTree, evaluatedRows).ToArray(); 106 double prunedMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, estimatedValues)); 107 double prunedSize = clonedTree.GetSize(); 108 // MSE of the pruned tree is larger than the original tree in most cases 109 // size of the pruned tree is always smaller than the size of the original tree 110 // same change in quality => prefer pruning operation that removes a larger tree 111 double gain = ((prunedMse / originalMse) * qualityGainWeight) / 112 (originalSize / prunedSize); 113 if (gain < bestGain) { 114 bestGain = gain; 115 bestTree = clonedTree; 167 Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator.")); 168 Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune.")); 169 Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation.")); 170 Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation")); 171 Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation.")); 172 Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation.")); 173 Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5))); 174 Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10))); 175 Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25))); 176 Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75))); 177 Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0))); 178 Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation.")); 179 Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation.")); 180 Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1))); 181 Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1))); 182 Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1))); 183 Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation.")); 184 Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection.")); 185 } 186 187 public override IOperation Apply() { 188 bool pruningCondition = 189 (Generation.Value >= FirstPruningGeneration.Value) && 190 ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0); 191 if (pruningCondition) { 192 int n = SymbolicExpressionTree.Length; 193 double percentileStart = PopulationPercentileStart.Value; 194 double percentileEnd = PopulationPercentileEnd.Value; 195 // for each tree in the given percentile 196 var trees = SymbolicExpressionTree 197 .Skip((int)(n * percentileStart)) 198 .Take((int)(n * (percentileEnd - percentileStart))); 199 foreach (var tree in trees) { 200 Prune(Random, tree, Iterations.Value, TournamentSize.Value, 201 DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value, 202 SymbolicExpressionTreeInterpreter, 203 LowerEstimationLimit.Value, UpperEstimationLimit.Value, 204 MaxPruningRatio.Value, QualityGainWeight.Value); 116 205 } 117 206 } 118 119 return bestTree; 120 } 121 122 private static FunctionTree CreateConstant(double constantValue) { 123 var node = (ConstantFunctionTree)(new Constant()).GetTreeNode(); 207 return base.Apply(); 208 } 209 210 public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize, 211 DataAnalysisProblemData problemData, int samplesStart, int samplesEnd, 212 ISymbolicExpressionTreeInterpreter interpreter, 213 double lowerEstimationLimit, double upperEstimationLimit, 214 double maxPruningRatio, double qualityGainWeight) { 215 IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart); 216 int originalSize = tree.Size; 217 double originalMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, tree, 218 lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, samplesStart, samplesEnd); 219 220 int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio)); 221 222 // tree for branch evaluation 223 SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone(); 224 while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0); 225 226 SymbolicExpressionTree prunedTree = tree; 227 for (int iteration = 0; iteration < iterations; iteration++) { 228 SymbolicExpressionTree iterationBestTree = prunedTree; 229 double bestGain = double.PositiveInfinity; 230 int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio); 231 232 for (int i = 0; i < tournamentSize; i++) { 233 var clonedTree = (SymbolicExpressionTree)prunedTree.Clone(); 234 int clonedTreeSize = clonedTree.Size; 235 var prunePoints = (from node in clonedTree.IterateNodesPostfix() 236 from subTree in node.SubTrees 237 let subTreeSize = subTree.GetSize() 238 where subTreeSize <= maxPrunedBranchSize 239 where clonedTreeSize - subTreeSize >= minPrunedSize 240 select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) }) 241 .ToList(); 242 if (prunePoints.Count > 0) { 243 var selectedPrunePoint = prunePoints.SelectRandom(random); 244 templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch); 245 IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows); 246 double branchMean = branchValues.Average(); 247 templateTree.Root.SubTrees[0].RemoveSubTree(0); 248 249 selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex); 250 var constNode = CreateConstant(branchMean); 251 selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode); 252 253 double prunedMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, clonedTree, 254 lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, samplesStart, samplesEnd); 255 double prunedSize = clonedTree.Size; 256 // MSE of the pruned tree is larger than the original tree in most cases 257 // size of the pruned tree is always smaller than the size of the original tree 258 // same change in quality => prefer pruning operation that removes a larger tree 259 double gain = ((prunedMse / originalMse) * qualityGainWeight) / 260 (originalSize / prunedSize); 261 if (gain < bestGain) { 262 bestGain = gain; 263 iterationBestTree = clonedTree; 264 } 265 } 266 } 267 prunedTree = iterationBestTree; 268 } 269 tree.Root = prunedTree.Root; 270 } 271 272 private static SymbolicExpressionTreeNode CreateConstant(double constantValue) { 273 var node = (ConstantTreeNode)(new Constant()).CreateTreeNode(); 124 274 node.Value = constantValue; 125 275 return node;
Note: See TracChangeset
for help on using the changeset viewer.