Free cookie consent management tool by TermsFeed Policy Generator

source: branches/2925_AutoDiffForDynamicalModels/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs @ 17187

Last change on this file since 17187 was 16662, checked in by gkronber, 6 years ago

#2925: merged all changes from trunk to branch (up to r16659)

File size: 6.3 KB
RevLine 
[11025]1#region License Information
2
3/* HeuristicLab
[16662]4 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[11025]5 *
6 * This file is part of HeuristicLab.
7 *
8 * HeuristicLab is free software: you can redistribute it and/or modify
9 * it under the terms of the GNU General Public License as published by
10 * the Free Software Foundation, either version 3 of the License, or
11 * (at your option) any later version.
12 *
13 * HeuristicLab is distributed in the hope that it will be useful,
14 * but WITHOUT ANY WARRANTY; without even the implied warranty of
15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
16 * GNU General Public License for more details.
17 *
18 * You should have received a copy of the GNU General Public License
19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
20 */
21
22#endregion
23
[12189]24using System.Collections.Generic;
[11025]25using System.Linq;
[10469]26using HeuristicLab.Common;
27using HeuristicLab.Core;
[12189]28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
[10469]29using HeuristicLab.Parameters;
[16662]30using HEAL.Attic;
[10469]31
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// Pruning operator for symbolic classification trees. It specializes
  /// <see cref="SymbolicDataAnalysisExpressionPruningOperator"/> by building classification models
  /// through a looked-up model creator and scoring them with a looked-up single-objective evaluator.
  /// </summary>
  [StorableType("F213025F-ACE7-4E43-A149-70C3AD824D19")]
  [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classification trees.")]
  public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
    private const string ModelCreatorParameterName = "ModelCreator";
    private const string EvaluatorParameterName = "Evaluator";

    #region parameter properties
    /// <summary>Lookup parameter registered under the name "ModelCreator".</summary>
    public ILookupParameter<ISymbolicClassificationModelCreator> ModelCreatorParameter {
      get { return (ILookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; }
    }

    /// <summary>Lookup parameter registered under the name "Evaluator".</summary>
    public ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator> EvaluatorParameter {
      get {
        return (ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName];
      }
    }
    #endregion

    /// <summary>Cloning constructor; all copying is delegated to the base class.</summary>
    protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner) : base(original, cloner) { }

    /// <summary>Creates a copy of this operator by invoking the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) { return new SymbolicClassificationPruningOperator(this, cloner); }

    /// <summary>Constructor marked for the storable (deserialization) infrastructure.</summary>
    [StorableConstructor]
    protected SymbolicClassificationPruningOperator(StorableConstructorFlag _) : base(_) { }

    /// <summary>
    /// Creates the operator with the given impact values calculator and registers the
    /// "ModelCreator" and "Evaluator" lookup parameters.
    /// </summary>
    /// <param name="impactValuesCalculator">Calculator forwarded to the base class constructor.</param>
    public SymbolicClassificationPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator)
      : base(impactValuesCalculator) {
      Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName));
      Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
    }

    /// <summary>
    /// Post-deserialization hook: assigns a fresh
    /// <see cref="SymbolicClassificationSolutionImpactValuesCalculator"/> to the base class and adds the
    /// "Evaluator" parameter when the deserialized instance does not contain it yet.
    /// </summary>
    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.3
      #region Backwards compatible code, remove with 3.4
      base.ImpactValuesCalculator = new SymbolicClassificationSolutionImpactValuesCalculator();
      if (!Parameters.ContainsKey(EvaluatorParameterName)) {
        Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName));
      }
      #endregion
    }

    /// <summary>
    /// Builds a symbolic classification model for <paramref name="tree"/> using the model creator from
    /// <see cref="ModelCreatorParameter"/> and recalculates its model parameters on the training indices.
    /// </summary>
    /// <param name="tree">Expression tree the model is created for.</param>
    /// <param name="interpreter">Interpreter handed to the model creator.</param>
    /// <param name="problemData">Problem data; cast to <see cref="IClassificationProblemData"/>.</param>
    /// <param name="estimationLimits">Its lower and upper bounds are handed to the model creator.</param>
    /// <returns>The created model after its parameters were recalculated.</returns>
    protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) {
      var classificationProblemData = (IClassificationProblemData)problemData;
      var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(classificationProblemData.TargetVariable, tree, interpreter, estimationLimits.Lower, estimationLimits.Upper);

      var rows = classificationProblemData.TrainingIndices;
      model.RecalculateModelParameters(classificationProblemData, rows);
      return model;
    }

    /// <summary>
    /// Evaluates the expression tree of <paramref name="model"/> with the evaluator from
    /// <see cref="EvaluatorParameter"/> on the rows of the fitness calculation partition.
    /// </summary>
    /// <param name="model">Model to evaluate; cast to <see cref="ISymbolicClassificationModel"/>.</param>
    /// <returns>The value returned by the evaluator.</returns>
    protected override double Evaluate(IDataAnalysisModel model) {
      var evaluator = EvaluatorParameter.ActualValue;
      var classificationModel = (ISymbolicClassificationModel)model;
      var classificationProblemData = (IClassificationProblemData)ProblemDataParameter.ActualValue;
      // resolve the partition once instead of once per property access
      var partition = FitnessCalculationPartitionParameter.ActualValue;
      var rows = Enumerable.Range(partition.Start, partition.Size);
      return evaluator.Evaluate(this.ExecutionContext, classificationModel.SymbolicExpressionTree, classificationProblemData, rows);
    }

    /// <summary>
    /// Prunes a clone of <paramref name="tree"/> by replacing low-impact nodes with constant nodes.
    /// The input tree itself is not modified; all replacements are applied to the clone.
    /// </summary>
    /// <param name="tree">Tree to prune (cloned before any modification).</param>
    /// <param name="modelCreator">Creates the classification model wrapping the cloned tree.</param>
    /// <param name="impactValuesCalculator">Supplies impact value, replacement value and quality per node.</param>
    /// <param name="interpreter">Interpreter handed to the model creator.</param>
    /// <param name="problemData">Classification problem data (target variable, rows).</param>
    /// <param name="estimationLimits">Its lower and upper bounds are handed to the model creator.</param>
    /// <param name="rows">Rows handed to the impact values calculator.</param>
    /// <param name="nodeImpactThreshold">
    /// When <paramref name="pruneOnlyZeroImpactNodes"/> is false, nodes with an impact value greater than this threshold are kept.
    /// </param>
    /// <param name="pruneOnlyZeroImpactNodes">
    /// When true, only nodes whose impact value is almost 0.0 are replaced and the threshold is ignored.
    /// </param>
    /// <returns>The symbolic expression tree of the model built on the pruned clone.</returns>
    /// <exception cref="System.ArgumentNullException">
    /// <paramref name="tree"/>, <paramref name="modelCreator"/>, <paramref name="impactValuesCalculator"/>,
    /// <paramref name="problemData"/> or <paramref name="estimationLimits"/> is null.
    /// </exception>
    public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicClassificationModelCreator modelCreator,
      SymbolicClassificationSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter,
      IClassificationProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows,
      double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) {
      // Only arguments dereferenced directly in this method are validated here;
      // interpreter and rows are merely forwarded to collaborators.
      if (tree == null) throw new System.ArgumentNullException("tree");
      if (modelCreator == null) throw new System.ArgumentNullException("modelCreator");
      if (impactValuesCalculator == null) throw new System.ArgumentNullException("impactValuesCalculator");
      if (problemData == null) throw new System.ArgumentNullException("problemData");
      if (estimationLimits == null) throw new System.ArgumentNullException("estimationLimits");

      var clonedTree = (ISymbolicExpressionTree)tree.Clone();
      var model = modelCreator.CreateSymbolicClassificationModel(problemData.TargetVariable, clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper);

      // Iteration starts two levels below the root (presumably skipping the program-root and
      // start-symbol nodes - confirm against the grammar). The list is a prefix-order snapshot.
      var nodes = clonedTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList();
      // NaN on the first call; presumably tells the calculator to compute the baseline quality itself - TODO confirm.
      double qualityForImpactsCalculation = double.NaN;

      for (int i = 0; i < nodes.Count; ++i) {
        var node = nodes[i];
        if (node is ConstantTreeNode) continue; // already a constant, nothing to fold

        double impactValue, replacementValue, newQualityForImpactsCalculation;
        impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation);

        if (pruneOnlyZeroImpactNodes && !impactValue.IsAlmost(0.0)) continue;
        if (!pruneOnlyZeroImpactNodes && impactValue > nodeImpactThreshold) continue;

        var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode();
        constantNode.Value = replacementValue;

        ReplaceWithConstant(node, constantNode);
        i += node.GetLength() - 1; // skip subtrees under the node that was folded

        // the quality is only carried forward after a node was actually replaced
        qualityForImpactsCalculation = newQualityForImpactsCalculation;
      }
      return model.SymbolicExpressionTree;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.