Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
08/26/10 12:46:41 (14 years ago)
Author:
gkronber
Message:

Overhauled pruning operator. #1142

File:
1 edited

Legend:

Unmodified
Added
Removed
  • branches/DataAnalysis/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs

    r4297 r4328  
    3232using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    3333using System;
     34using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
    3435
    3536namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
     
    272273      double maxPruningRatio, double qualityGainWeight) {
    273274
     275      int originalSize = tree.Size;
     276
     277      // min size of the resulting pruned tree
     278      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));
     279
     280      // use the same subset of rows for all iterations and for all pruning tournaments
    274281      IEnumerable<int> rows = RandomEnumerable.SampleRandomNumbers(samplesStart, samplesEnd, (int)Math.Ceiling((samplesEnd - samplesStart) * relativeNumberOfEvaluatedRows));
    275       int originalSize = tree.Size;
    276 
    277       int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));
    278       // tree for branch evaluation
    279       SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
    280       while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);
    281 
    282282      SymbolicExpressionTree prunedTree = tree;
    283       double currentQuality = quality.Value;
    284283      for (int iteration = 0; iteration < iterations; iteration++) {
    285         SymbolicExpressionTree iterationBestTree = prunedTree;
    286         double bestGain = double.PositiveInfinity;
    287         int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);
    288 
    289         for (int i = 0; i < tournamentSize; i++) {
    290           var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
    291           int clonedTreeSize = clonedTree.Size;
    292           var prunePoints = (from node in clonedTree.Root.SubTrees[0].IterateNodesPostfix()
    293                              from subTree in node.SubTrees
    294                              let subTreeSize = subTree.GetSize()
    295                              where subTreeSize <= maxPrunedBranchSize
    296                              where clonedTreeSize - subTreeSize >= minPrunedSize
    297                              select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
    298                  .ToList();
    299           if (prunePoints.Count > 0) {
    300             var selectedPrunePoint = prunePoints.SelectRandom(random);
    301             templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
    302             IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
    303             double branchMean = branchValues.Average();
    304             templateTree.Root.SubTrees[0].RemoveSubTree(0);
    305 
    306             selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
    307             var constNode = CreateConstant(branchMean);
    308             selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);
    309 
    310             double prunedQuality = evaluator.Evaluate(interpreter, clonedTree,
    311         lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, rows);
    312             double prunedSize = clonedTree.Size;
    313             // deteriation in quality:
    314             // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement
    315             //      MSE : newMse > origMse (deteriation) => prefer the smaller deteriation
    316             //      MSE : minimize: newMse / origMse
    317             //      R²  : newR² > origR²   (improvment) => prefer the larger improvment
    318             //      R²  : newR² < origR²   (deteriation) => prefer smaller deteriation
    319             //      R²  : minimize: origR² / newR²
    320             double qualityDeteriation = maximization ? quality.Value / prunedQuality : prunedQuality / quality.Value;
    321             // size of the pruned tree is always smaller than the size of the original tree
    322             // same change in quality => prefer pruning operation that removes a larger tree
    323             double gain = (qualityDeteriation * qualityGainWeight) /
    324                            (originalSize / prunedSize);
    325             if (gain < bestGain) {
    326               bestGain = gain;
    327               iterationBestTree = clonedTree;
    328               currentQuality = prunedQuality;
    329             }
     284        // maximally prune a branch such that the resulting tree size is not smaller than (1-maxPruningRatio) of the original tree
     285        int maxPrunedBranchSize = tree.Size - minPrunedSize;
     286        if (maxPrunedBranchSize > 0) {
     287          PruneTournament(prunedTree, quality, random, tournamentSize, maxPrunedBranchSize, maximization, qualityGainWeight, evaluator, interpreter, problemData.Dataset, problemData.TargetVariable.Value, rows, lowerEstimationLimit, upperEstimationLimit);
     288        }
     289      }
     290    }
     291
     292    private class PruningPoint {
     293      public SymbolicExpressionTreeNode Parent { get; private set; }
     294      public SymbolicExpressionTreeNode Branch { get; private set; }
     295      public int SubTreeIndex { get; private set; }
     296      public PruningPoint(SymbolicExpressionTreeNode parent, SymbolicExpressionTreeNode branch, int index) {
     297        Parent = parent;
     298        Branch = branch;
     299        SubTreeIndex = index;
     300      }
     301    }
     302
     303    private static void PruneTournament(SymbolicExpressionTree tree, DoubleValue quality, IRandom random, int tournamentSize,
     304      int maxPrunedBranchSize, bool maximization, double qualityGainWeight, ISymbolicRegressionEvaluator evaluator, ISymbolicExpressionTreeInterpreter interpreter,
     305      Dataset ds, string targetVariable, IEnumerable<int> rows, double lowerEstimationLimit, double upperEstimationLimit) {
     306      // make a clone for pruningEvaluation
     307      SymbolicExpressionTree pruningEvaluationTree = (SymbolicExpressionTree)tree.Clone();
     308      var prunePoints = (from node in pruningEvaluationTree.Root.SubTrees[0].IterateNodesPostfix()
     309                         from subTree in node.SubTrees
     310                         let subTreeSize = subTree.GetSize()
     311                         where subTreeSize <= maxPrunedBranchSize
     312                         where !(subTree.Symbol is Constant)
     313                         select new PruningPoint(node, subTree, node.SubTrees.IndexOf(subTree)))
     314         .ToList();
     315      double originalQuality = quality.Value;
     316      double originalSize = tree.Size;
     317      if (prunePoints.Count > 0) {
     318        double bestCoeff = double.PositiveInfinity;
     319        List<PruningPoint> tournamentGroup;
     320        if (prunePoints.Count > tournamentSize) {
     321          tournamentGroup = new List<PruningPoint>();
     322          for (int i = 0; i < tournamentSize; i++) {
     323            tournamentGroup.Add(prunePoints.SelectRandom(random));
    330324          }
     325        } else {
     326          tournamentGroup = prunePoints;
    331327        }
    332         prunedTree = iterationBestTree;
    333       }
    334 
    335       quality.Value = currentQuality;
    336       tree.Root = prunedTree.Root;
     328        foreach (PruningPoint prunePoint in tournamentGroup) {
     329          double replacementValue = CalculateReplacementValue(prunePoint.Branch, interpreter, ds, rows);
     330
     331          // temporarily replace the branch with a constant
     332          prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
     333          var constNode = CreateConstant(replacementValue);
     334          prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, constNode);
     335
     336          // evaluate the pruned tree
     337          double prunedQuality = evaluator.Evaluate(interpreter, pruningEvaluationTree,
     338  lowerEstimationLimit, upperEstimationLimit, ds, targetVariable, rows);
     339
     340          double prunedSize = originalSize - prunePoint.Branch.GetSize() + 1;
     341
     342          double coeff = CalculatePruningCoefficient(maximization, qualityGainWeight, originalQuality, originalSize, prunedQuality, prunedSize);
     343          if (coeff < bestCoeff) {
     344            bestCoeff = coeff;
     345            // clone the currently pruned tree
     346            SymbolicExpressionTree bestTree = (SymbolicExpressionTree)pruningEvaluationTree.Clone();
     347
     348            // and update original tree and quality
     349            tree.Root = bestTree.Root;
     350            quality.Value = prunedQuality;
     351          }
     352
     353          // restore tree that is used for pruning evaluation
     354          prunePoint.Parent.RemoveSubTree(prunePoint.SubTreeIndex);
     355          prunePoint.Parent.InsertSubTree(prunePoint.SubTreeIndex, prunePoint.Branch);
     356        }
     357      }
     358    }
     359
     360    private static double CalculatePruningCoefficient(bool maximization, double qualityGainWeight, double originalQuality, double originalSize, double prunedQuality, double prunedSize) {
     361      // deteriation in quality:
     362      // exp: MSE : newMse < origMse (improvement) => prefer the larger improvement
     363      //      MSE : newMse > origMse (deteriation) => prefer the smaller deteriation
     364      //      MSE : minimize: newMse / origMse
     365      //      R²  : newR² > origR²   (improvment) => prefer the larger improvment
     366      //      R²  : newR² < origR²   (deteriation) => prefer smaller deteriation
     367      //      R²  : minimize: origR² / newR²
     368      double qualityDeteriation = maximization ? originalQuality / prunedQuality : prunedQuality / originalQuality;
     369      // size of the pruned tree is always smaller than the size of the original tree
     370      // same change in quality => prefer pruning operation that removes a larger tree
     371      return (qualityDeteriation * qualityGainWeight) / (originalSize / prunedSize);
     372    }
     373
     374    private static double CalculateReplacementValue(SymbolicExpressionTreeNode branch, ISymbolicExpressionTreeInterpreter interpreter, Dataset ds, IEnumerable<int> rows) {
     375      SymbolicExpressionTreeNode start = (new StartSymbol()).CreateTreeNode();
     376      start.AddSubTree(branch);
     377      SymbolicExpressionTreeNode root = (new ProgramRootSymbol()).CreateTreeNode();
     378      root.AddSubTree(start);
     379      SymbolicExpressionTree tree = new SymbolicExpressionTree(root);
     380      IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(tree, ds, rows);
     381      return branchValues.Average();
    337382    }
    338383
Note: See TracChangeset for help on using the changeset viewer.