Free cookie consent management tool by TermsFeed Policy Generator

Ignore:
Timestamp:
07/12/15 11:23:06 (9 years ago)
Author:
gkronber
Message:

#2359, #2398: merged r12189,r12358,r12359,r12361,r12461,r12674,r12720,r12744 from trunk to stable

Location:
stable
Files:
3 edited

Legend:

Unmodified
Added
Removed
  • stable

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification

  • stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs

    r12009 r12745  
    2222#endregion
    2323
     24using System.Collections.Generic;
    2425using System.Linq;
    2526using HeuristicLab.Common;
    2627using HeuristicLab.Core;
     28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
    2729using HeuristicLab.Parameters;
    2830using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
     
    3234  [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classificaton trees.")]
    3335  public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
    34     private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
    3536    private const string ModelCreatorParameterName = "ModelCreator";
     37    private const string EvaluatorParameterName = "Evaluator";
    3638
    3739    #region parameter properties
     
    3941      get { return (ILookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; }
    4042    }
     43
     44    public ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator> EvaluatorParameter {
     45      get {
     46        return (ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName];
     47      }
     48    }
    4149    #endregion
    4250
    43     protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner)
    44       : base(original, cloner) {
    45     }
    46 
    47     public override IDeepCloneable Clone(Cloner cloner) {
    48       return new SymbolicClassificationPruningOperator(this, cloner);
    49     }
     51    protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner) : base(original, cloner) { }
     52    public override IDeepCloneable Clone(Cloner cloner) { return new SymbolicClassificationPruningOperator(this, cloner); }
    5053
    5154    [StorableConstructor]
    5255    protected SymbolicClassificationPruningOperator(bool deserializing) : base(deserializing) { }
    5356
    54     public SymbolicClassificationPruningOperator() {
    55       Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, new SymbolicClassificationSolutionImpactValuesCalculator()));
     57    public SymbolicClassificationPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator)
     58      : base(impactValuesCalculator) {
    5659      Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName));
     60      Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
    5761    }
    5862
    59     protected override ISymbolicDataAnalysisModel CreateModel() {
    60       var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper);
    61       var problemData = (IClassificationProblemData)ProblemData;
    62       var rows = problemData.TrainingIndices;
    63       model.RecalculateModelParameters(problemData, rows);
     63    [StorableHook(HookType.AfterDeserialization)]
     64    private void AfterDeserialization() {
     65      // BackwardsCompatibility3.3
     66      #region Backwards compatible code, remove with 3.4
     67      base.ImpactValuesCalculator = new SymbolicClassificationSolutionImpactValuesCalculator();
     68      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
     69        Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
     70      }
     71      #endregion
     72    }
     73
     74    protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
     75      var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     76      var classificationProblemData = (IClassificationProblemData)problemData;
     77      var rows = classificationProblemData.TrainingIndices;
     78      model.RecalculateModelParameters(classificationProblemData, rows);
    6479      return model;
    6580    }
    6681
    6782    protected override double Evaluate(IDataAnalysisModel model) {
    68       var classificationModel = (IClassificationModel)model;
    69       var classificationProblemData = (IClassificationProblemData)ProblemData;
    70       var trainingIndices = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
    71       var estimatedValues = classificationModel.GetEstimatedClassValues(ProblemData.Dataset, trainingIndices);
    72       var targetValues = ProblemData.Dataset.GetDoubleValues(classificationProblemData.TargetVariable, trainingIndices);
    73       OnlineCalculatorError errorState;
    74       var quality = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out errorState);
    75       if (errorState != OnlineCalculatorError.None) return double.NaN;
    76       return quality;
     83      var evaluator = EvaluatorParameter.ActualValue;
     84      var classificationModel = (ISymbolicClassificationModel)model;
     85      var classificationProblemData = (IClassificationProblemData)ProblemDataParameter.ActualValue;
     86      var rows = Enumerable.Range(FitnessCalculationPartitionParameter.ActualValue.Start, FitnessCalculationPartitionParameter.ActualValue.Size);
     87      return evaluator.Evaluate(this.ExecutionContext, classificationModel.SymbolicExpressionTree, classificationProblemData, rows);
     88    }
     89
     90    public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicClassificationModelCreator modelCreator,
     91      SymbolicClassificationSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
     92      IClassificationProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows,
     93      double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) {
     94      var clonedTree = (ISymbolicExpressionTree)tree.Clone();
     95      var model = modelCreator.CreateSymbolicClassificationModel(clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper);
     96
     97      var nodes = clonedTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList();
     98      double qualityForImpactsCalculation = double.NaN;
     99
     100      for (int i = 0; i < nodes.Count; ++i) {
     101        var node = nodes[i];
     102        if (node is ConstantTreeNode) continue;
     103
     104        double impactValue, replacementValue, newQualityForImpactsCalculation;
     105        impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation);
     106
     107        if (pruneOnlyZeroImpactNodes && !impactValue.IsAlmost(0.0)) continue;
     108        if (!pruneOnlyZeroImpactNodes && impactValue > nodeImpactThreshold) continue;
     109
     110        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
     111        constantNode.Value = replacementValue;
     112
     113        ReplaceWithConstant(node, constantNode);
     114        i += node.GetLength() - 1; // skip subtrees under the node that was folded
     115
     116        qualityForImpactsCalculation = newQualityForImpactsCalculation;
     117      }
     118      return model.SymbolicExpressionTree;
    77119    }
    78120  }
Note: See TracChangeset for help on using the changeset viewer.