Ignore:
Timestamp:
07/12/15 11:23:06 (7 years ago)
Author:
gkronber
Message:

#2359, #2398: merged r12189,r12358,r12359,r12361,r12461,r12674,r12720,r12744 from trunk to stable

Location:
stable
Files:
5 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningAnalyzer.cs

    r12009 r12745  
    2222using HeuristicLab.Common;
    2323using HeuristicLab.Core;
     24using HeuristicLab.Data;
    2425using HeuristicLab.Parameters;
    2526using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    2930  [StorableClass]
    3031  public sealed class SymbolicClassificationPruningAnalyzer : SymbolicDataAnalysisSingleObjectivePruningAnalyzer {
    31     private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
    3232    private const string PruningOperatorParameterName = "PruningOperator";
    33     private SymbolicClassificationPruningAnalyzer(SymbolicClassificationPruningAnalyzer original, Cloner cloner)
    34       : base(original, cloner) {
     33    public IValueParameter<SymbolicClassificationPruningOperator> PruningOperatorParameter {
     34      get { return (IValueParameter<SymbolicClassificationPruningOperator>)Parameters[PruningOperatorParameterName]; }
    3535    }
    36     public override IDeepCloneable Clone(Cloner cloner) {
    37       return new SymbolicClassificationPruningAnalyzer(this, cloner);
     36
     37    protected override SymbolicDataAnalysisExpressionPruningOperator PruningOperator {
     38      get { return PruningOperatorParameter.Value; }
    3839    }
     40
     41    private SymbolicClassificationPruningAnalyzer(SymbolicClassificationPruningAnalyzer original, Cloner cloner) : base(original, cloner) { }
     42    public override IDeepCloneable Clone(Cloner cloner) { return new SymbolicClassificationPruningAnalyzer(this, cloner); }
    3943
    4044    [StorableConstructor]
     
    4246
    4347    public SymbolicClassificationPruningAnalyzer() {
    44       Parameters.Add(new ValueParameter<SymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, "The impact values calculator", new SymbolicClassificationSolutionImpactValuesCalculator()));
    45       Parameters.Add(new ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator()));
     48      Parameters.Add(new ValueParameter<SymbolicClassificationPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator(new SymbolicClassificationSolutionImpactValuesCalculator())));
     49    }
     50
     // Storable hook executed after deserialization. Migrates the parameter collection of a
     // previously persisted analyzer to the current layout in two steps:
     //   1. make sure a "PruningOperator" parameter of type
     //      ValueParameter<SymbolicClassificationPruningOperator> exists;
     //   2. move the values of the parameters "PruneOnlyZeroImpactNodes", "ImpactThreshold" and
     //      "ImpactValuesCalculator" (when present) onto the pruning operator and remove those parameters.
     51    [StorableHook(HookType.AfterDeserialization)]
     52    private void AfterDeserialization() {
     53      // BackwardsCompatibility3.3
     54
     55      #region Backwards compatible code, remove with 3.4
     56      if (Parameters.ContainsKey(PruningOperatorParameterName)) {
             // The 'as' cast yields null unless the stored parameter has the old, more general type;
             // in that case the existing parameter is left as it is.
     57        var oldParam = Parameters[PruningOperatorParameterName] as ValueParameter<SymbolicDataAnalysisExpressionPruningOperator>;
     58        if (oldParam != null) {
               // The old parameter is replaced by one holding a freshly constructed operator;
               // the operator instance stored in oldParam is not carried over.
     59          Parameters.Remove(oldParam);
     60          Parameters.Add(new ValueParameter<SymbolicClassificationPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator(new SymbolicClassificationSolutionImpactValuesCalculator())));
     61        }
     62      } else {
     63        // not yet contained
     64        Parameters.Add(new ValueParameter<SymbolicClassificationPruningOperator>(PruningOperatorParameterName, "The operator used to prune trees", new SymbolicClassificationPruningOperator(new SymbolicClassificationSolutionImpactValuesCalculator())));
     65      }
     66
           // Each of the following parameters, if present, has its value copied to the matching
           // property of the pruning operator and is then removed from this analyzer.
     67      if (Parameters.ContainsKey("PruneOnlyZeroImpactNodes")) {
     68        PruningOperator.PruneOnlyZeroImpactNodes = ((IFixedValueParameter<BoolValue>)Parameters["PruneOnlyZeroImpactNodes"]).Value.Value;
     69        Parameters.Remove(Parameters["PruneOnlyZeroImpactNodes"]);
     70      }
     71      if (Parameters.ContainsKey("ImpactThreshold")) {
     72        PruningOperator.NodeImpactThreshold = ((IFixedValueParameter<DoubleValue>)Parameters["ImpactThreshold"]).Value.Value;
     73        Parameters.Remove(Parameters["ImpactThreshold"]);
     74      }
     75      if (Parameters.ContainsKey("ImpactValuesCalculator")) {
             // Direct cast (not 'as'): throws InvalidCastException if the stored parameter has a different type.
     76        PruningOperator.ImpactValuesCalculator = ((ValueParameter<SymbolicDataAnalysisSolutionImpactValuesCalculator>)Parameters["ImpactValuesCalculator"]).Value;
     77        Parameters.Remove(Parameters["ImpactValuesCalculator"]);
     78      }
     79      #endregion
    4680    }
    4781  }
  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs

    r12009 r12745  
    2222#endregion
    2323
     24using System.Collections.Generic;
    2425using System.Linq;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
     28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2729using HeuristicLab.Parameters;
    2830using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    3234  [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classificaton trees.")]
    3335  public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
    34     private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
    3536    private const string ModelCreatorParameterName = "ModelCreator";
     37    private const string EvaluatorParameterName = "Evaluator";
    3638
    3739    #region parameter properties
     
    3941      get { return (ILookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; }
    4042    }
     43
     // Typed accessor for the "Evaluator" lookup parameter: fetches the entry stored under
     // EvaluatorParameterName from Parameters and casts it to
     // ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>.
     44    public ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator> EvaluatorParameter {
     45      get {
     46        return (ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName];
     47      }
     48    }
    4149    #endregion
    4250
    43     protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner)
    44       : base(original, cloner) {
    45     }
    46 
    47     public override IDeepCloneable Clone(Cloner cloner) {
    48       return new SymbolicClassificationPruningOperator(this, cloner);
    49     }
     51    protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner) : base(original, cloner) { }
     52    public override IDeepCloneable Clone(Cloner cloner) { return new SymbolicClassificationPruningOperator(this, cloner); }
    5053
    5154    [StorableConstructor]
    5255    protected SymbolicClassificationPruningOperator(bool deserializing) : base(deserializing) { }
    5356
    54     public SymbolicClassificationPruningOperator() {
    55       Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, new SymbolicClassificationSolutionImpactValuesCalculator()));
     57    public SymbolicClassificationPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator)
     58      : base(impactValuesCalculator) {
    5659      Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName));
     60      Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
    5761    }
    5862
    59     protected override ISymbolicDataAnalysisModel CreateModel() {
    60       var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper);
    61       var problemData = (IClassificationProblemData)ProblemData;
    62       var rows = problemData.TrainingIndices;
    63       model.RecalculateModelParameters(problemData, rows);
     // Storable hook executed after deserialization of the pruning operator. Assigns a new
     // SymbolicClassificationSolutionImpactValuesCalculator to the base class property and adds the
     // "Evaluator" lookup parameter when the persisted operator does not contain it yet.
     63    [StorableHook(HookType.AfterDeserialization)]
     64    private void AfterDeserialization() {
     65      // BackwardsCompatibility3.3
     66      #region Backwards compatible code, remove with 3.4
           // NOTE(review): this assignment is unconditional, so it replaces whatever calculator the
           // base property held after deserialization - confirm that this is intended.
     67      base.ImpactValuesCalculator = new SymbolicClassificationSolutionImpactValuesCalculator();
     68      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
     69        Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
     70      }
     71      #endregion
     72    }
     73
     // Builds a symbolic classification model for the given tree using the model creator obtained
     // from ModelCreatorParameter, bounded by estimationLimits.Lower / estimationLimits.Upper, and
     // recalculates the model parameters on the training indices of the problem data.
     // problemData is cast to IClassificationProblemData; any other type raises InvalidCastException.
     74    protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
     75      var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     76      var classificationProblemData = (IClassificationProblemData)problemData;
     77      var rows = classificationProblemData.TrainingIndices;
     78      model.RecalculateModelParameters(classificationProblemData, rows);
    6479      return model;
    6580    }
    6681
    6782    protected override double Evaluate(IDataAnalysisModel model) {
    68       var classificationModel = (IClassificationModel)model;
    69       var classificationProblemData = (IClassificationProblemData)ProblemData;
    70       var trainingIndices = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
    71       var estimatedValues = classificationModel.GetEstimatedClassValues(ProblemData.Dataset, trainingIndices);
    72       var targetValues = ProblemData.Dataset.GetDoubleValues(classificationProblemData.TargetVariable, trainingIndices);
    73       OnlineCalculatorError errorState;
    74       var quality = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out errorState);
    75       if (errorState != OnlineCalculatorError.None) return double.NaN;
    76       return quality;
     83      var evaluator = EvaluatorParameter.ActualValue;
     84      var classificationModel = (ISymbolicClassificationModel)model;
     85      var classificationProblemData = (IClassificationProblemData)ProblemDataParameter.ActualValue;
     86      var rows = Enumerable.Range(FitnessCalculationPartitionParameter.ActualValue.Start, FitnessCalculationPartitionParameter.ActualValue.Size);
     87      return evaluator.Evaluate(this.ExecutionContext, classificationModel.SymbolicExpressionTree, classificationProblemData, rows);
     88    }
     89
     // Static pruning routine: works on a clone of 'tree' and replaces nodes by constant nodes
     // when their impact value qualifies them for pruning.
     //   tree                     - tree to prune; it is cloned first and the clone is modified.
     //   modelCreator             - creates the classification model wrapping the cloned tree.
     //   impactValuesCalculator   - supplies impact value, replacement value and the new quality per node.
     //   interpreter, estimationLimits - passed on to the model creator.
     //   problemData, rows        - passed on to the impact values calculator.
     //   nodeImpactThreshold      - used when pruneOnlyZeroImpactNodes is false: nodes whose impact is
     //                              greater than this threshold are kept.
     //   pruneOnlyZeroImpactNodes - when true, only nodes whose impact IsAlmost(0.0) are replaced.
     // Returns model.SymbolicExpressionTree.
     // NOTE(review): presumably the model references the cloned tree, so the returned tree is the
     // pruned clone - confirm against the model creator implementation.
     90    public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicClassificationModelCreator modelCreator,
     91      SymbolicClassificationSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     92      IClassificationProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows,
     93      double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) {
     94      var clonedTree = (ISymbolicExpressionTree)tree.Clone();
     95      var model = modelCreator.CreateSymbolicClassificationModel(clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     96
           // Candidate nodes: prefix-order traversal starting at the first subtree of the first subtree
           // of the root, materialized into a list so that it can be indexed while the tree is edited.
     97      var nodes = clonedTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList();
           // Starts as NaN; after each replacement it holds the quality reported for the modified tree.
     98      double qualityForImpactsCalculation = double.NaN;
     99
     100      for (int i = 0; i < nodes.Count; ++i) {
     101        var node = nodes[i];
             // Constant nodes are never replaced.
     102        if (node is ConstantTreeNode) continue;
     103
     104        double impactValue, replacementValue, newQualityForImpactsCalculation;
     105        impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation);
     106
             // Keep the node unless it satisfies the active pruning criterion.
     107        if (pruneOnlyZeroImpactNodes && !impactValue.IsAlmost(0.0)) continue;
     108        if (!pruneOnlyZeroImpactNodes && impactValue > nodeImpactThreshold) continue;
     109
             // The replacement node is created from the symbol named "Constant" in the node's grammar.
     110        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
     111        constantNode.Value = replacementValue;
     112
             // ReplaceWithConstant is not defined in the visible part of this file.
     113        ReplaceWithConstant(node, constantNode);
     114        i += node.GetLength() - 1; // skip subtrees under the node that was folded
     115
             // The quality of the modified tree becomes the reference for the following nodes.
     116        qualityForImpactsCalculation = newQualityForImpactsCalculation;
     117      }
     118      return model.SymbolicExpressionTree;
    77119    }
    78120  }
  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationSolutionImpactValuesCalculator.cs

    r12009 r12745  
    4747    }
    4848
    49     public override double CalculateImpactValue(ISymbolicDataAnalysisModel model, ISymbolicExpressionTreeNode node, IDataAnalysisProblemData problemData, IEnumerable<int> rows, double originalQuality = double.NaN) {
     49    public override double CalculateImpactValue(ISymbolicDataAnalysisModel model, ISymbolicExpressionTreeNode node, IDataAnalysisProblemData problemData, IEnumerable<int> rows, double qualityForImpactsCalculation = double.NaN) {
    5050      double impactValue, replacementValue;
    51       CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, originalQuality);
     51      double newQualityForImpactsCalculation;
     52      CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation);
    5253      return impactValue;
    5354    }
    5455
    5556    public override void CalculateImpactAndReplacementValues(ISymbolicDataAnalysisModel model, ISymbolicExpressionTreeNode node,
    56       IDataAnalysisProblemData problemData, IEnumerable<int> rows, out double impactValue, out double replacementValue,
    57       double originalQuality = Double.NaN) {
     57      IDataAnalysisProblemData problemData, IEnumerable<int> rows, out double impactValue, out double replacementValue, out double newQualityForImpactsCalculation,
     58      double qualityForImpactsCalculation = Double.NaN) {
    5859      var classificationModel = (ISymbolicClassificationModel)model;
    5960      var classificationProblemData = (IClassificationProblemData)problemData;
    6061
    61       var dataset = classificationProblemData.Dataset;
    62       var targetClassValues = dataset.GetDoubleValues(classificationProblemData.TargetVariable, rows);
    63 
    64       OnlineCalculatorError errorState;
    65       if (double.IsNaN(originalQuality)) {
    66         var originalClassValues = classificationModel.GetEstimatedClassValues(dataset, rows);
    67         originalQuality = OnlineAccuracyCalculator.Calculate(targetClassValues, originalClassValues, out errorState);
    68         if (errorState != OnlineCalculatorError.None) originalQuality = 0.0;
    69       }
     62      if (double.IsNaN(qualityForImpactsCalculation))
     63        qualityForImpactsCalculation = CalculateQualityForImpacts(classificationModel, classificationProblemData, rows);
    7064
    7165      replacementValue = CalculateReplacementValue(classificationModel, node, classificationProblemData, rows);
     
    8175      tempModelParentNode.InsertSubtree(i, constantNode);
    8276
     77      OnlineCalculatorError errorState;
     78      var dataset = classificationProblemData.Dataset;
     79      var targetClassValues = dataset.GetDoubleValues(classificationProblemData.TargetVariable, rows);
    8380      var estimatedClassValues = tempModel.GetEstimatedClassValues(dataset, rows);
    84       double newQuality = OnlineAccuracyCalculator.Calculate(targetClassValues, estimatedClassValues, out errorState);
    85       if (errorState != OnlineCalculatorError.None) newQuality = 0.0;
     81      newQualityForImpactsCalculation = OnlineAccuracyCalculator.Calculate(targetClassValues, estimatedClassValues, out errorState);
     82      if (errorState != OnlineCalculatorError.None) newQualityForImpactsCalculation = 0.0;
    8683
    87       impactValue = originalQuality - newQuality;
     84      impactValue = qualityForImpactsCalculation - newQualityForImpactsCalculation;
     85    }
     86
     // Computes the reference quality used for impact calculation: the result of
     // OnlineAccuracyCalculator.Calculate for the target variable values versus the class values
     // estimated by 'model', both taken from problemData.Dataset on the given rows.
     // Returns 0.0 when the calculator reports any error state other than OnlineCalculatorError.None.
     87    public static double CalculateQualityForImpacts(ISymbolicClassificationModel model, IClassificationProblemData problemData, IEnumerable<int> rows) {
     88      OnlineCalculatorError errorState;
     89      var dataset = problemData.Dataset;
     90      var targetClassValues = dataset.GetDoubleValues(problemData.TargetVariable, rows);
     91      var originalClassValues = model.GetEstimatedClassValues(dataset, rows);
     92      var qualityForImpactsCalculation = OnlineAccuracyCalculator.Calculate(targetClassValues, originalClassValues, out errorState);
           // A failed calculation is mapped to a quality of 0.0 rather than being propagated.
     93      if (errorState != OnlineCalculatorError.None) qualityForImpactsCalculation = 0.0;
     94
     95      return qualityForImpactsCalculation;
    8896    }
    8997  }
Note: See TracChangeset for help on using the changeset viewer.