Changeset 12189


Ignore:
Timestamp: 03/11/15 14:07:50 (6 years ago)
Author: bburlacu
Message: #2359: Implemented improvements
Location: trunk/sources
Files: 5 edited

Legend:

Unmodified
Added
Removed
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningAnalyzer.cs

    r12012 r12189  
    4343    public SymbolicClassificationPruningAnalyzer() {
    4444      Parameters.Add(new ValueParameter<SymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, "The impact values calculator", new SymbolicClassificationSolutionImpactValuesCalculator()));
    45       Parameters.Add(new ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator()));
     45      Parameters.Add(new ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator(new SymbolicClassificationSolutionImpactValuesCalculator())));
    4646    }
    4747  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs

    r12012 r12189  
    2222#endregion
    2323
     24using System.Collections.Generic;
    2425using System.Linq;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
     28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2729using HeuristicLab.Parameters;
    2830using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    3234  [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classificaton trees.")]
    3335  public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
    34     private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
    3536    private const string ModelCreatorParameterName = "ModelCreator";
    3637
     
    5253    protected SymbolicClassificationPruningOperator(bool deserializing) : base(deserializing) { }
    5354
    54     public SymbolicClassificationPruningOperator() {
    55       Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, new SymbolicClassificationSolutionImpactValuesCalculator()));
     55    public SymbolicClassificationPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator)
     56      : base(impactValuesCalculator) {
    5657      Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName));
    5758    }
    5859
    59     protected override ISymbolicDataAnalysisModel CreateModel() {
    60       var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper);
    61       var problemData = (IClassificationProblemData)ProblemData;
    62       var rows = problemData.TrainingIndices;
    63       model.RecalculateModelParameters(problemData, rows);
     60    protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
     61      var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     62      var classificationProblemData = (IClassificationProblemData)problemData;
     63      var rows = classificationProblemData.TrainingIndices;
     64      model.RecalculateModelParameters(classificationProblemData, rows);
    6465      return model;
    6566    }
     
    6970      var classificationProblemData = (IClassificationProblemData)ProblemData;
    7071      var trainingIndices = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
    71       var estimatedValues = classificationModel.GetEstimatedClassValues(ProblemData.Dataset, trainingIndices);
    72       var targetValues = ProblemData.Dataset.GetDoubleValues(classificationProblemData.TargetVariable, trainingIndices);
     72
     73      return Evaluate(classificationModel, classificationProblemData, trainingIndices);
     74    }
     75
     76    private static double Evaluate(IClassificationModel model, IClassificationProblemData problemData, IEnumerable<int> rows) {
     77      var estimatedValues = model.GetEstimatedClassValues(problemData.Dataset, rows);
     78      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    7379      OnlineCalculatorError errorState;
    7480      var quality = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out errorState);
     
    7682      return quality;
    7783    }
     84
     85    public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicClassificationModelCreator modelCreator,
     86      SymbolicClassificationSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     87      IClassificationProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows,
     88      double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) {
     89      var clonedTree = (ISymbolicExpressionTree)tree.Clone();
     90      var model = modelCreator.CreateSymbolicClassificationModel(clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     91
     92      var nodes = clonedTree.IterateNodesPrefix().ToList();
     93      double quality = Evaluate(model, problemData, rows);
     94
     95      for (int i = 0; i < nodes.Count; ++i) {
     96        var node = nodes[i];
     97        if (node is ConstantTreeNode) continue;
     98
     99        double impactValue, replacementValue;
     100        impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, quality);
     101
     102        if (pruneOnlyZeroImpactNodes) {
     103          if (!impactValue.IsAlmost(0.0)) continue;
     104        } else if (nodeImpactThreshold < impactValue) {
     105          continue;
     106        }
     107
     108        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
     109        constantNode.Value = replacementValue;
     110
     111        ReplaceWithConstant(node, constantNode);
     112        i += node.GetLength() - 1; // skip subtrees under the node that was folded
     113
     114        quality -= impactValue;
     115      }
     116      return model.SymbolicExpressionTree;
     117    }
    78118  }
    79119}
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionPruningAnalyzer.cs

    r12012 r12189  
    4545    public SymbolicRegressionPruningAnalyzer() {
    4646      Parameters.Add(new ValueParameter<SymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, "The impact values calculator", new SymbolicRegressionSolutionImpactValuesCalculator()));
    47       Parameters.Add(new ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicRegressionPruningOperator()));
     47      Parameters.Add(new ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicRegressionPruningOperator(new SymbolicRegressionSolutionImpactValuesCalculator())));
    4848    }
    4949  }
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionPruningOperator.cs

    r12012 r12189  
    2222#endregion
    2323
     24using System.Collections.Generic;
    2425using System.Linq;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
    27 using HeuristicLab.Parameters;
     28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2829using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
    2930
     
    3233  [Item("SymbolicRegressionPruningOperator", "An operator which prunes symbolic regression trees.")]
    3334  public class SymbolicRegressionPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
    34     private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
    35 
    3635    protected SymbolicRegressionPruningOperator(SymbolicRegressionPruningOperator original, Cloner cloner)
    3736      : base(original, cloner) {
     
    4443    protected SymbolicRegressionPruningOperator(bool deserializing) : base(deserializing) { }
    4544
    46     public SymbolicRegressionPruningOperator() {
    47       var impactValuesCalculator = new SymbolicRegressionSolutionImpactValuesCalculator();
    48       Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, "The impact values calculator to be used for figuring out the node impacts.", impactValuesCalculator));
     45    public SymbolicRegressionPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator)
     46      : base(impactValuesCalculator) {
    4947    }
    5048
    51     protected override ISymbolicDataAnalysisModel CreateModel() {
    52       return new SymbolicRegressionModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper);
     49    protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
     50      return new SymbolicRegressionModel(tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
    5351    }
    5452
     
    5654      var regressionModel = (IRegressionModel)model;
    5755      var regressionProblemData = (IRegressionProblemData)ProblemData;
    58       var trainingIndices = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
    59       var estimatedValues = regressionModel.GetEstimatedValues(ProblemData.Dataset, trainingIndices); // also bounds the values
    60       var targetValues = ProblemData.Dataset.GetDoubleValues(regressionProblemData.TargetVariable, trainingIndices);
     56      var rows = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
     57      return Evaluate(regressionModel, regressionProblemData, rows);
     58    }
     59
     60    private static double Evaluate(IRegressionModel model, IRegressionProblemData problemData,
     61      IEnumerable<int> rows) {
     62      var estimatedValues = model.GetEstimatedValues(problemData.Dataset, rows); // also bounds the values
     63      var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    6164      OnlineCalculatorError errorState;
    6265      var quality = OnlinePearsonsRSquaredCalculator.Calculate(targetValues, estimatedValues, out errorState);
     
    6467      return quality;
    6568    }
     69
     70    public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, SymbolicRegressionSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IRegressionProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows, double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) {
     71      var clonedTree = (ISymbolicExpressionTree)tree.Clone();
     72      var model = new SymbolicRegressionModel(clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     73      var nodes = clonedTree.IterateNodesPrefix().ToList();
     74      double quality = Evaluate(model, problemData, rows);
     75
     76      for (int i = 0; i < nodes.Count; ++i) {
     77        var node = nodes[i];
     78        if (node is ConstantTreeNode) continue;
     79
     80        double impactValue, replacementValue;
     81        impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, quality);
     82
     83        if (pruneOnlyZeroImpactNodes) {
     84          if (!impactValue.IsAlmost(0.0)) continue;
     85        } else if (nodeImpactThreshold < impactValue) {
     86          continue;
     87        }
     88
     89        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
     90        constantNode.Value = replacementValue;
     91
     92        ReplaceWithConstant(node, constantNode);
     93        i += node.GetLength() - 1; // skip subtrees under the node that was folded
     94
     95        quality -= impactValue;
     96      }
     97      return model.SymbolicExpressionTree;
     98    }
    6699  }
    67100}
  • trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic/3.4/SymbolicDataAnalysisExpressionPruningOperator.cs

    r12012 r12189  
    109109      : base(original, cloner) { }
    110110
    111     protected SymbolicDataAnalysisExpressionPruningOperator() {
     111    protected SymbolicDataAnalysisExpressionPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator) {
    112112      #region add parameters
    113113      Parameters.Add(new LookupParameter<IDataAnalysisProblemData>(ProblemDataParameterName));
     
    122122      Parameters.Add(new LookupParameter<ISymbolicExpressionTree>(SymbolicExpressionTreeParameterName));
    123123      Parameters.Add(new LookupParameter<DoubleValue>(QualityParameterName));
     124      Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, impactValuesCalculator));
    124125      #endregion
    125126    }
    126127
    127     protected abstract ISymbolicDataAnalysisModel CreateModel();
     128    protected abstract ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits);
    128129
    129130    protected abstract double Evaluate(IDataAnalysisModel model);
    130131
    131132    public override IOperation Apply() {
    132       var model = CreateModel();
     133      var model = CreateModel(SymbolicExpressionTree, Interpreter, ProblemData, EstimationLimits);
    133134      var nodes = SymbolicExpressionTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList();
    134135      var rows = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
     
    169170    }
    170171
    171     private static void ReplaceWithConstant(ISymbolicExpressionTreeNode original, ISymbolicExpressionTreeNode replacement) {
     172    public ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
     173      var model = CreateModel((ISymbolicExpressionTree)tree.Clone(), Interpreter, ProblemData, EstimationLimits);
     174      var nodes = SymbolicExpressionTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList();
     175      var rows = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
     176
     177      double quality = Evaluate(model);
     178
     179      for (int i = 0; i < nodes.Count; ++i) {
     180        var node = nodes[i];
     181        if (node is ConstantTreeNode) continue;
     182
     183        double impactValue, replacementValue;
     184        ImpactValuesCalculator.CalculateImpactAndReplacementValues(model, node, ProblemData, rows, out impactValue, out replacementValue, quality);
     185
     186        if (PruneOnlyZeroImpactNodes) {
     187          if (!impactValue.IsAlmost(0.0)) continue;
     188        } else if (NodeImpactThreshold < impactValue) {
     189          continue;
     190        }
     191
     192        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
     193        constantNode.Value = replacementValue;
     194
     195        ReplaceWithConstant(node, constantNode);
     196        i += node.GetLength() - 1; // skip subtrees under the node that was folded
     197
     198        quality -= impactValue;
     199      }
     200      return model.SymbolicExpressionTree;
     201    }
     202
     203    protected static void ReplaceWithConstant(ISymbolicExpressionTreeNode original, ISymbolicExpressionTreeNode replacement) {
    172204      var parent = original.Parent;
    173205      var i = parent.IndexOfSubtree(original);
Note: See TracChangeset for help on using the changeset viewer.