Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/13/10 11:48:24 (14 years ago)
Author:
gkronber
Message:

Ported pruning operator for symbolic regression solutions from version 3.2 to version 3.3. #125

Location:
trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers
Files:
1 deleted
1 copied

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Regression/3.3/Symbolic/Analyzers/SymbolicRegressionTournamentPruning.cs

    r3901 r4028  
    11#region License Information
    22/* HeuristicLab
    3  * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
     3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
    44 *
    55 * This file is part of HeuristicLab.
     
    2424using HeuristicLab.Core;
    2525using HeuristicLab.Data;
    26 using HeuristicLab.GP.Interfaces;
    2726using System;
    28 using HeuristicLab.DataAnalysis;
    29 using HeuristicLab.Modeling;
    30 
    31 namespace HeuristicLab.GP.StructureIdentification {
    32   public class TournamentPruning : OperatorBase {
    33     public TournamentPruning()
     27using HeuristicLab.Operators;
     28using HeuristicLab.Parameters;
     29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
     30using HeuristicLab.Problems.DataAnalysis.Symbolic;
     31using HeuristicLab.Problems.DataAnalysis.Symbolic.Symbols;
     32using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding.Symbols;
     33using HeuristicLab.Optimization;
     34
     35namespace HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers {
     36  public class SymbolicRegressionTournamentPruning : SingleSuccessorOperator, ISymbolicRegressionAnalyzer {
     37    private const string RandomParameterName = "Random";
     38    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
     39    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
     40    private const string SamplesStartParameterName = "SamplesStart";
     41    private const string SamplesEndParameterName = "SamplesEnd";
     42    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";
     43    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
     44    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";
     45    private const string MaxPruningRatioParameterName = "MaxPruningRatio";
     46    private const string TournamentSizeParameterName = "TournamentSize";
     47    private const string PopulationPercentileStartParameterName = "PopulationPercentileStart";
     48    private const string PopulationPercentileEndParameterName = "PopulationPercentileEnd";
     49    private const string QualityGainWeightParameterName = "QualityGainWeight";
     50    private const string IterationsParameterName = "Iterations";
     51    private const string FirstPruningGenerationParameterName = "FirstPruningGeneration";
     52    private const string PruningFrequencyParameterName = "PruningFrequency";
     53    private const string GenerationParameterName = "Generations";
     54    private const string ResultsParameterName = "Results";
     55
     56    #region parameter properties
           // Typed accessors for the operator's parameters. Each getter looks the
           // parameter up in the Parameters collection by its *ParameterName constant
           // and casts it to the declared parameter type.
     57    public ILookupParameter<IRandom> RandomParameter {
     58      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
     59    }
           // Declared as ScopeTreeLookupParameter; its ActualValue is exposed below as an
           // ItemArray<SymbolicExpressionTree>, i.e. several trees rather than a single one.
     60    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
     61      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
     62    }
     63    public ILookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
     64      get { return (ILookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
     65    }
     66    public ILookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
     67      get { return (ILookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
     68    }
     69    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
     70      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
     71    }
     72    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
     73      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
     74    }
     75    public IValueLookupParameter<IntValue> SamplesStartParameter {
     76      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesStartParameterName]; }
     77    }
     78    public IValueLookupParameter<IntValue> SamplesEndParameter {
     79      get { return (IValueLookupParameter<IntValue>)Parameters[SamplesEndParameterName]; }
     80    }
     81    public IValueLookupParameter<DoubleValue> MaxPruningRatioParameter {
     82      get { return (IValueLookupParameter<DoubleValue>)Parameters[MaxPruningRatioParameterName]; }
     83    }
     84    public IValueLookupParameter<IntValue> TournamentSizeParameter {
     85      get { return (IValueLookupParameter<IntValue>)Parameters[TournamentSizeParameterName]; }
     86    }
     87    public IValueLookupParameter<DoubleValue> PopulationPercentileStartParameter {
     88      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileStartParameterName]; }
     89    }
     90    public IValueLookupParameter<DoubleValue> PopulationPercentileEndParameter {
     91      get { return (IValueLookupParameter<DoubleValue>)Parameters[PopulationPercentileEndParameterName]; }
     92    }
     93    public IValueLookupParameter<DoubleValue> QualityGainWeightParameter {
     94      get { return (IValueLookupParameter<DoubleValue>)Parameters[QualityGainWeightParameterName]; }
     95    }
     96    public IValueLookupParameter<IntValue> IterationsParameter {
     97      get { return (IValueLookupParameter<IntValue>)Parameters[IterationsParameterName]; }
     98    }
     99    public IValueLookupParameter<IntValue> FirstPruningGenerationParameter {
     100      get { return (IValueLookupParameter<IntValue>)Parameters[FirstPruningGenerationParameterName]; }
     101    }
     102    public IValueLookupParameter<IntValue> PruningFrequencyParameter {
     103      get { return (IValueLookupParameter<IntValue>)Parameters[PruningFrequencyParameterName]; }
     104    }
            // NOTE(review): GenerationParameterName is the string "Generations" (plural),
            // while this property and its description use the singular — confirm the
            // lookup name matches the variable the surrounding algorithm provides.
     105    public ILookupParameter<IntValue> GenerationParameter {
     106      get { return (ILookupParameter<IntValue>)Parameters[GenerationParameterName]; }
     107    }
     108    public ILookupParameter<ResultCollection> ResultsParameter {
     109      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
     110    }
     111    #endregion
     112    #region properties
            // Convenience getters: each one returns the ActualValue of the matching
            // *Parameter property, so the rest of the class can read the values directly.
            // There is no getter here for ResultsParameter.
     113    public IRandom Random {
     114      get { return RandomParameter.ActualValue; }
     115    }
            // Despite the singular name this is an array of trees (ItemArray); Apply()
            // reads its Length and enumerates a slice of it.
     116    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
     117      get { return SymbolicExpressionTreeParameter.ActualValue; }
     118    }
     119    public DataAnalysisProblemData DataAnalysisProblemData {
     120      get { return DataAnalysisProblemDataParameter.ActualValue; }
     121    }
     122    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
     123      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
     124    }
     125    public DoubleValue UpperEstimationLimit {
     126      get { return UpperEstimationLimitParameter.ActualValue; }
     127    }
     128    public DoubleValue LowerEstimationLimit {
     129      get { return LowerEstimationLimitParameter.ActualValue; }
     130    }
     131    public IntValue SamplesStart {
     132      get { return SamplesStartParameter.ActualValue; }
     133    }
     134    public IntValue SamplesEnd {
     135      get { return SamplesEndParameter.ActualValue; }
     136    }
     137    public DoubleValue MaxPruningRatio {
     138      get { return MaxPruningRatioParameter.ActualValue; }
     139    }
     140    public IntValue TournamentSize {
     141      get { return TournamentSizeParameter.ActualValue; }
     142    }
     143    public DoubleValue PopulationPercentileStart {
     144      get { return PopulationPercentileStartParameter.ActualValue; }
     145    }
     146    public DoubleValue PopulationPercentileEnd {
     147      get { return PopulationPercentileEndParameter.ActualValue; }
     148    }
     149    public DoubleValue QualityGainWeight {
     150      get { return QualityGainWeightParameter.ActualValue; }
     151    }
     152    public IntValue Iterations {
     153      get { return IterationsParameter.ActualValue; }
     154    }
     155    public IntValue PruningFrequency {
     156      get { return PruningFrequencyParameter.ActualValue; }
     157    }
     158    public IntValue FirstPruningGeneration {
     159      get { return FirstPruningGenerationParameter.ActualValue; }
     160    }
     161    public IntValue Generation {
     162      get { return GenerationParameter.ActualValue; }
     163    }
     164    #endregion
     165    public SymbolicRegressionTournamentPruning()
    34166      : base() {
    35       AddVariableInfo(new VariableInfo("Random", "", typeof(IRandom), VariableKind.In));
    36       AddVariableInfo(new VariableInfo("FunctionTree", "The tree to analyse", typeof(IGeneticProgrammingModel), VariableKind.In));
    37       AddVariableInfo(new VariableInfo("Dataset", "Dataset", typeof(Dataset), VariableKind.In));
    38       AddVariableInfo(new VariableInfo("TargetVariable", "", typeof(StringData), VariableKind.In));
    39       AddVariableInfo(new VariableInfo("TrainingSamplesStart", "Samples start", typeof(IntData), VariableKind.In));
    40       AddVariableInfo(new VariableInfo("TrainingSamplesEnd", "Samples end", typeof(IntData), VariableKind.In));
    41       AddVariableInfo(new VariableInfo("TreeEvaluator", "", typeof(ITreeEvaluator), VariableKind.In));
    42       AddVariableInfo(new VariableInfo("MaxPruningRatio", "Maximale relative size of the pruned branch", typeof(DoubleData), VariableKind.In));
    43       AddVariableInfo(new VariableInfo("TournamentSize", "Number of branches to compare for pruning", typeof(IntData), VariableKind.
    44 In));
    45       AddVariableInfo(new VariableInfo("PopulationPercentileStart", "", typeof(DoubleData), VariableKind.In));
    46       AddVariableInfo(new VariableInfo("PopulationPercentileEnd", "", typeof(DoubleData), VariableKind.In));
    47       AddVariableInfo(new VariableInfo("QualityGainWeight", "", typeof(DoubleData), VariableKind.In));
    48     }
    49 
    50     public override IOperation Apply(IScope scope) {
    51       IRandom random = scope.GetVariableValue<IRandom>("Random", true);
    52       double percentileStart = scope.GetVariableValue<DoubleData>("PopulationPercentileStart", true).Data;
    53       double percentileEnd = scope.GetVariableValue<DoubleData>("PopulationPercentileEnd", true).Data;
    54       int tournamentSize = scope.GetVariableValue<IntData>("TournamentSize", true).Data;
    55       Dataset dataset = scope.GetVariableValue<Dataset>("Dataset", true);
    56       string targetVariable = scope.GetVariableValue<StringData>("TargetVariable", true).Data;
    57       int samplesStart = scope.GetVariableValue<IntData>("TrainingSamplesStart", true).Data;
    58       int samplesEnd = scope.GetVariableValue<IntData>("TrainingSamplesEnd", true).Data;
    59       ITreeEvaluator evaluator = scope.GetVariableValue<ITreeEvaluator>("TreeEvaluator", true);
    60       double maxPruningRatio = scope.GetVariableValue<DoubleData>("MaxPruningRatio", true).Data;
    61       double qualityGainWeight = scope.GetVariableValue<DoubleData>("QualityGainWeight", true).Data;
    62       int n = scope.SubScopes.Count;
    63       // for each tree in the given percentile
    64       var trees = (from subScope in scope.SubScopes
    65                    select subScope.GetVariableValue<IGeneticProgrammingModel>("FunctionTree", false))
    66                   .Skip((int)(n * percentileStart))
    67                   .Take((int)(n * (percentileEnd - percentileStart)));
    68       foreach (var tree in trees) {
    69         tree.FunctionTree = Prune(random, tree.FunctionTree, tournamentSize, dataset, targetVariable, samplesStart, samplesEnd, evaluator, maxPruningRatio, qualityGainWeight);
    70       }
    71       return null;
    72     }
    73 
    74     public static IFunctionTree Prune(IRandom random, IFunctionTree tree, int tournamentSize,
    75       Dataset dataset, string targetVariable, int samplesStart, int samplesEnd, ITreeEvaluator evaluator,
    76       double maxPruningRatio, double qualityGainWeight) {
    77       var evaluatedRows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
    78       var estimatedValues = evaluator.Evaluate(dataset, tree, evaluatedRows).ToArray();
    79       var targetValues = dataset.GetVariableValues(targetVariable, samplesStart, samplesEnd);
    80       int originalSize = tree.GetSize();
    81       double originalMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, estimatedValues));
    82 
    83       int maxPrunedBranchSize = (int)(tree.GetSize() * maxPruningRatio);
    84 
    85 
    86       IFunctionTree bestTree = tree;
    87       double bestGain = double.PositiveInfinity;
    88 
    89       for (int i = 0; i < tournamentSize; i++) {
    90         var clonedTree = (IFunctionTree)tree.Clone();
    91         var prunePoints = (from node in FunctionTreeIterator.IteratePrefix(clonedTree)
    92                            from subTree in node.SubTrees
    93                            where subTree.GetSize() <= maxPrunedBranchSize
    94                            select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
    95                .ToList();
    96 
    97         var selectedPrunePoint = prunePoints[random.Next(prunePoints.Count)];
    98         var branchValues = evaluator.Evaluate(dataset, selectedPrunePoint.Branch, evaluatedRows);
    99         var branchMean = branchValues.Average();
    100 
    101         selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
    102         var constNode = CreateConstant(branchMean);
    103         selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);
    104 
    105         estimatedValues = evaluator.Evaluate(dataset, clonedTree, evaluatedRows).ToArray();
    106         double prunedMse = SimpleMSEEvaluator.Calculate(Matrix<double>.Create(targetValues, estimatedValues));
    107         double prunedSize = clonedTree.GetSize();
    108         // MSE of the pruned tree is larger than the original tree in most cases
    109         // size of the pruned tree is always smaller than the size of the original tree
    110         // same change in quality => prefer pruning operation that removes a larger tree
    111         double gain = ((prunedMse / originalMse) * qualityGainWeight) /
    112                        (originalSize / prunedSize);
    113         if (gain < bestGain) {
    114           bestGain = gain;
    115           bestTree = clonedTree;
     167      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "A random number generator."));
     168      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to prune."));
     169      Parameters.Add(new LookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for branch impact evaluation."));
     170      Parameters.Add(new LookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter to use for node impact evaluation"));
     171      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesStartParameterName, "The first row index of the dataset partition to use for branch impact evaluation."));
     172      Parameters.Add(new ValueLookupParameter<IntValue>(SamplesEndParameterName, "The last row index of the dataset partition to use for branch impact evaluation."));
     173      Parameters.Add(new ValueLookupParameter<DoubleValue>(MaxPruningRatioParameterName, "The maximal relative size of the pruned branch.", new DoubleValue(0.5)));
     174      Parameters.Add(new ValueLookupParameter<IntValue>(TournamentSizeParameterName, "The number of branches to compare for pruning", new IntValue(10)));
     175      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileStartParameterName, "The start of the population percentile to consider for pruning.", new DoubleValue(0.25)));
     176      Parameters.Add(new ValueLookupParameter<DoubleValue>(PopulationPercentileEndParameterName, "The end of the population percentile to consider for pruning.", new DoubleValue(0.75)));
     177      Parameters.Add(new ValueLookupParameter<DoubleValue>(QualityGainWeightParameterName, "The weight of the quality gain relative to the size gain.", new DoubleValue(1.0)));
     178      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit to use for evaluation."));
     179      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit to use for evaluation."));
     180      Parameters.Add(new ValueLookupParameter<IntValue>(IterationsParameterName, "The number of pruning iterations to apply for each tree.", new IntValue(1)));
     181      Parameters.Add(new ValueLookupParameter<IntValue>(FirstPruningGenerationParameterName, "The first generation when pruning should be applied.", new IntValue(1)));
     182      Parameters.Add(new ValueLookupParameter<IntValue>(PruningFrequencyParameterName, "The frequency of pruning operations (1: every generation, 2: every second generation...)", new IntValue(1)));
     183      Parameters.Add(new LookupParameter<IntValue>(GenerationParameterName, "The current generation."));
     184      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The results collection."));
     185    }
     186
     187    public override IOperation Apply() {
     188      bool pruningCondition =
     189        (Generation.Value >= FirstPruningGeneration.Value) &&
     190        ((Generation.Value - FirstPruningGeneration.Value) % PruningFrequency.Value == 0);
     191      if (pruningCondition) {
     192        int n = SymbolicExpressionTree.Length;
     193        double percentileStart = PopulationPercentileStart.Value;
     194        double percentileEnd = PopulationPercentileEnd.Value;
     195        // for each tree in the given percentile
     196        var trees = SymbolicExpressionTree
     197          .Skip((int)(n * percentileStart))
     198          .Take((int)(n * (percentileEnd - percentileStart)));
     199        foreach (var tree in trees) {
     200          Prune(Random, tree, Iterations.Value, TournamentSize.Value,
     201            DataAnalysisProblemData, SamplesStart.Value, SamplesEnd.Value,
     202            SymbolicExpressionTreeInterpreter,
     203            LowerEstimationLimit.Value, UpperEstimationLimit.Value,
     204            MaxPruningRatio.Value, QualityGainWeight.Value);
    116205        }
    117206      }
    118 
    119       return bestTree;
    120     }
    121 
    122     private static FunctionTree CreateConstant(double constantValue) {
    123       var node = (ConstantFunctionTree)(new Constant()).GetTreeNode();
     207      return base.Apply();
     208    }
     209
     210    public static void Prune(IRandom random, SymbolicExpressionTree tree, int iterations, int tournamentSize,
     211      DataAnalysisProblemData problemData, int samplesStart, int samplesEnd,
     212      ISymbolicExpressionTreeInterpreter interpreter,
     213      double lowerEstimationLimit, double upperEstimationLimit,
     214      double maxPruningRatio, double qualityGainWeight) {
              // Prunes `tree` in place by tournament selection.
              //
              // For each of `iterations` rounds, up to `tournamentSize` candidates are
              // built. Each candidate is a clone of the current tree in which one
              // randomly chosen branch is replaced by a constant node holding the mean
              // of the values the interpreter returns for that branch (evaluated through
              // `templateTree`). The candidate with the smallest `gain` wins the round
              // and becomes the starting point of the next round. At the end tree.Root
              // is set to the root of the final tree; nothing is returned.
              //
              // Rows used for the branch mean are samplesStart .. samplesEnd - 1.
              // NOTE(review): the SamplesEnd parameter description says "last row index",
              // but Enumerable.Range below excludes samplesEnd itself — confirm intent.
     215      IEnumerable<int> rows = Enumerable.Range(samplesStart, samplesEnd - samplesStart);
     216      int originalSize = tree.Size;
              // Error of the unpruned tree; used as the fixed reference for every gain below.
     217      double originalMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, tree,
     218        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, samplesStart, samplesEnd);
     219
              // Lower bound for the size of any pruned tree, relative to the ORIGINAL size
              // (it is not recomputed between iterations).
     220      int minPrunedSize = (int)(originalSize * (1 - maxPruningRatio));
     221
     222      // tree for branch evaluation
              // A clone of `tree` whose Root.SubTrees[0] has had all children removed. A
              // candidate branch is temporarily attached there so the interpreter can be
              // run on the branch alone.
     223      SymbolicExpressionTree templateTree = (SymbolicExpressionTree)tree.Clone();
     224      while (templateTree.Root.SubTrees[0].SubTrees.Count > 0) templateTree.Root.SubTrees[0].RemoveSubTree(0);
     225
     226      SymbolicExpressionTree prunedTree = tree;
     227      for (int iteration = 0; iteration < iterations; iteration++) {
     228        SymbolicExpressionTree iterationBestTree = prunedTree;
     229        double bestGain = double.PositiveInfinity;
                // Upper bound for a removable branch, relative to the CURRENT tree size.
     230        int maxPrunedBranchSize = (int)(prunedTree.Size * maxPruningRatio);
     231
     232        for (int i = 0; i < tournamentSize; i++) {
     233          var clonedTree = (SymbolicExpressionTree)prunedTree.Clone();
     234          int clonedTreeSize = clonedTree.Size;
                  // All (parent, branch) pairs of the clone whose branch is small enough
                  // and whose removal would not shrink the tree below minPrunedSize.
     235          var prunePoints = (from node in clonedTree.IterateNodesPostfix()
     236                             from subTree in node.SubTrees
     237                             let subTreeSize = subTree.GetSize()
     238                             where subTreeSize <= maxPrunedBranchSize
     239                             where clonedTreeSize - subTreeSize >= minPrunedSize
     240                             select new { Parent = node, Branch = subTree, SubTreeIndex = node.SubTrees.IndexOf(subTree) })
     241                 .ToList();
                  // No admissible prune point: this tournament slot produces no candidate.
     242          if (prunePoints.Count > 0) {
     243            var selectedPrunePoint = prunePoints.SelectRandom(random);
                    // Attach the branch to the template, evaluate, then detach it again so
                    // the template is empty for the next candidate.
     244            templateTree.Root.SubTrees[0].AddSubTree(selectedPrunePoint.Branch);
     245            IEnumerable<double> branchValues = interpreter.GetSymbolicExpressionTreeValues(templateTree, problemData.Dataset, rows);
     246            double branchMean = branchValues.Average();
     247            templateTree.Root.SubTrees[0].RemoveSubTree(0);
     248
                    // Replace the branch in the clone by a constant at the same child index.
     249            selectedPrunePoint.Parent.RemoveSubTree(selectedPrunePoint.SubTreeIndex);
     250            var constNode = CreateConstant(branchMean);
     251            selectedPrunePoint.Parent.InsertSubTree(selectedPrunePoint.SubTreeIndex, constNode);
     252
     253            double prunedMse = SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, clonedTree,
     254        lowerEstimationLimit, upperEstimationLimit, problemData.Dataset, problemData.TargetVariable.Value, samplesStart, samplesEnd);
                    // Stored as double so that originalSize / prunedSize below is a
                    // floating-point division rather than an integer one.
     255            double prunedSize = clonedTree.Size;
     256            // MSE of the pruned tree is larger than the original tree in most cases
     257            // size of the pruned tree is always smaller than the size of the original tree
     258            // same change in quality => prefer pruning operation that removes a larger tree
                    // Smaller gain is better. Algebraically
                    //   gain = qualityGainWeight * prunedMse * prunedSize / (originalMse * originalSize).
                    // NOTE(review): qualityGainWeight multiplies every candidate's gain by
                    // the same factor, so for a positive weight it cannot change which
                    // candidate wins — confirm whether a different weighting was intended.
                    // NOTE(review): if originalMse is 0 the quotient is Infinity or NaN, the
                    // comparison below is then false (for a positive weight) and no
                    // candidate is ever accepted — confirm this is acceptable.
     259            double gain = ((prunedMse / originalMse) * qualityGainWeight) /
     260                           (originalSize / prunedSize);
     261            if (gain < bestGain) {
     262              bestGain = gain;
     263              iterationBestTree = clonedTree;
     264            }
     265          }
     266        }
                // The best candidate is always adopted; there is no check that it is not
                // worse than the unpruned tree. If no candidate was produced, the tree is
                // carried over unchanged.
     267        prunedTree = iterationBestTree;
     268      }
              // Write the result back into the caller's tree object. If nothing was
              // pruned, prunedTree is still `tree` and this is a self-assignment.
     269      tree.Root = prunedTree.Root;
     270    }
     271
     272    private static SymbolicExpressionTreeNode CreateConstant(double constantValue) {
     273      var node = (ConstantTreeNode)(new Constant()).CreateTreeNode();
    124274      node.Value = constantValue;
    125275      return node;
Note: See TracChangeset for help on using the changeset viewer.