Changeset 4328
- Timestamp:
- 08/26/10 12:46:41 (14 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs
r4297 r4328 32 32 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; 33 33 using System; 34 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols; 34 35 35 36 namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers { … … 272 273 double maxPruningRatio, double qualityGainWeight) { 273 274 275 int originalSize = tree.Size; 276 277 // min size of the resulting pruned tree 278 int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio)); 279 280 // use the same subset of rows for all iterations and for all pruning tournaments 274 281 IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows)); 275 int originalSize = tree.Size;276 277 int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));278 // tree for branch evaluation279 SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();280 while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);281 282 282 SymbolicExpressionTree prunedTree = tree; 283 double currentQuality = quality.Value;284 283 for (int iteration = 0; iteration < iterations; iteration++) { 285 SymbolicExpressionTree iterationBestTree = prunedTree; 286 double bestGain = double.PositiveInfinity; 287 int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio); 288 289 for (int i = 0; i < tournamentSize; i++) { 290 var clonedTree = (SymbolicExpressionTree)prunedTree.Clone(); 291 int clonedTreeSize = clonedTree.Size; 292 var prunePoints = (from node in clonedTree.Root.SubTrees[0].IterateNodesPostfix() 293 from subTree in node.SubTrees 294 let subTreeSize = subTree.GetSize() 295 where subTreeSize <= maxPrunedBranchSize 296 where clonedTreeSize - subTreeSize >= minPrunedSize 297 select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) }) 298 .ToList(); 299 if (prunePoints.Count > 0) { 300 var selectedPrunePoint = prunePoints.SelectRandom(random); 301 templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch); 302 IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows); 303 double branchMean = branchValues.Average(); 304 templateTree.Root.SubTrees[0].RemoveSubTree(0); 305 306 selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex); 307 var constNode = CreateConstant(branchMean); 308 selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode); 309 310 double prunedQuality = evaluator.Evaluate(interpreter, clonedTree, 311 lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows); 312 double prunedSize = clonedTree.Size; 313 // deteriation in quality: 314 // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement 315 // MSE : newMse > origMse (deteriation) => prefer the smaller deteriation 316 // MSE : minimize: newMse / origMse 317 // R² : newR² > origR² (improvment) => prefer the larger improvment 318 // R² : newR² < origR² (deteriation) => prefer smaller deteriation 319 // R² : minimize: origR² / newR² 320 double qualityDeteriation = maximization ? quality.Value / prunedQuality : prunedQuality / quality.Value; 321 // size of the pruned tree is always smaller than the size of the original tree 322 // same change in quality => prefer pruning operation that removes a larger tree 323 double gain = (qualityDeteriation * qualityGainWeight) / 324 (originalSize / prunedSize); 325 if (gain < bestGain) { 326 bestGain = gain; 327 iterationBestTree = clonedTree; 328 currentQuality = prunedQuality; 329 } 284 // maximally prune a branch such that the resulting tree size is not smaller than (1-maxPruningRatio) of the original tree 285 int maxPrunedBranchSize = tree.Size - minPrunedSize; 286 if (maxPrunedBranchSize > 0) { 287 PruneTournament(prunedTree, quality, random, tournamentSize, maxPrunedBranchSize, maximization, qualityGainWeight, evaluator, interpreter, problemData.Dataset, problemData.TargetVariable.Value, rows, lowerEstimationLimit, upperEstimationLimit); 288 } 289 } 290 } 291 292 private class PruningPoint { 293 public SymbolicExpressionTreeNode Parent { get; private set; } 294 public SymbolicExpressionTreeNode Branch { get; private set; } 295 public int SubTreeIndex { get; private set; } 296 public PruningPoint(SymbolicExpressionTreeNode parent, SymbolicExpressionTreeNode branch, int index) { 297 Parent = parent; 298 Branch = branch; 299 SubTreeIndex = index; 300 } 301 } 302 303 private static void PruneTournament(SymbolicExpressionTree tree, DoubleValue quality, IRandom random, int tournamentSize, 304 int maxPrunedBranchSize, bool maximization, double qualityGainWeight, ISymbolicRegressionEvaluator evaluator, ISymbolicExpressionTreeInterpreter interpreter, 305 Dataset ds, string targetVariable, IEnumerable<int> rows, double lowerEstimationLimit, double upperEstimationLimit) { 306 // make a clone for pruningEvaluation 307 SymbolicExpressionTree pruningEvaluationTree = (SymbolicExpressionTree)tree.Clone(); 308 var prunePoints = (from node in pruningEvaluationTree.Root.SubTrees[0].IterateNodesPostfix() 309 from subTree in node.SubTrees 310 let subTreeSize = subTree.GetSize() 311 where subTreeSize <= maxPrunedBranchSize 312 where !(subTree.Symbol is Constant) 313 select new PruningPoint(node, subTree, node.SubTrees.IndexOf(subTree))) 314 .ToList(); 315 double originalQuality = quality.Value; 316 double originalSize = tree.Size; 317 if (prunePoints.Count > 0) { 318 double bestCoeff = double.PositiveInfinity; 319 List<PruningPoint> tournamentGroup; 320 if (prunePoints.Count > tournamentSize) { 321 tournamentGroup = new List<PruningPoint>(); 322 for (int i = 0; i < tournamentSize; i++) { 323 tournamentGroup.Add(prunePoints.SelectRandom(random)); 330 324 } 325 } else { 326 tournamentGroup = prunePoints; 331 327 } 332 prunedTree = iterationBestTree; 333 } 334 335 quality.Value = currentQuality; 336 tree.Root = prunedTree.Root; 328 foreach (PruningPoint prunePoint in tournamentGroup) { 329 double replacementValue = CalculateReplacementValue(prunePoint.Branch, interpreter, ds, rows); 330 331 // temporarily replace the branch with a constant 332 prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex); 333 var constNode = CreateConstant(replacementValue); 334 prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, constNode); 335 336 // evaluate the pruned tree 337 double prunedQuality = evaluator.Evaluate(interpreter, pruningEvaluationTree, 338 lowerEstimationLimit, upperEstimationLimit, ds, targetVariable, rows); 339 340 double prunedSize = originalSize - prunePoint.Branch.GetSize() + 1; 341 342 double coeff = CalculatePruningCoefficient(maximization, qualityGainWeight, originalQuality, originalSize, prunedQuality, prunedSize); 343 if (coeff < bestCoeff) { 344 bestCoeff = coeff; 345 // clone the currently pruned tree 346 SymbolicExpressionTree bestTree = (SymbolicExpressionTree)pruningEvaluationTree.Clone(); 347 348 // and update original tree and quality 349 tree.Root = bestTree.Root; 350 quality.Value = prunedQuality; 351 } 352 353 // restore tree that is used for pruning evaluation 354 prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex); 355 prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, prunePoint.Branch); 356 } 357 } 358 } 359 360 private static double CalculatePruningCoefficient(bool maximization, double qualityGainWeight, double originalQuality, double originalSize, double prunedQuality, double prunedSize) { 361 // deteriation in quality: 362 // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement 363 // MSE : newMse > origMse (deteriation) => prefer the smaller deteriation 364 // MSE : minimize: newMse / origMse 365 // R² : newR² > origR² (improvment) => prefer the larger improvment 366 // R² : newR² < origR² (deteriation) => prefer smaller deteriation 367 // R² : minimize: origR² / newR² 368 double qualityDeteriation = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality; 369 // size of the pruned tree is always smaller than the size of the original tree 370 // same change in quality => prefer pruning operation that removes a larger tree 371 return (qualityDeteriation * qualityGainWeight) / (originalSize / prunedSize); 372 } 373 374 private static double CalculateReplacementValue(SymbolicExpressionTreeNode branch, ISymbolicExpressionTreeInterpreter interpreter, Dataset ds, IEnumerable<int> rows) { 375 SymbolicExpressionTreeNode start = (new StartSymbol()).CreateTreeNode(); 376 start.AddSubTree(branch); 377 SymbolicExpressionTreeNode root = (new ProgramRootSymbol()).CreateTreeNode(); 378 root.AddSubTree(start); 379 SymbolicExpressionTree tree = new SymbolicExpressionTree(root); 380 IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(tree, ds, rows); 381 return branchValues.Average(); 337 382 } 338 383
Note: See TracChangeset
for help on using the changeset viewer.