Changeset 12745 for stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs
Legend:
- Unmodified
- Added
- Removed
-
stable
- Property svn:mergeinfo changed
/trunk/sources merged: 12189,12358-12359,12361,12461,12674,12720,12744
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification
- Property svn:mergeinfo changed
/trunk/sources/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification merged: 12189,12358,12461,12720,12744
- Property svn:mergeinfo changed
-
stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs
r12009 r12745 22 22 #endregion 23 23 24 using System.Collections.Generic; 24 25 using System.Linq; 25 26 using HeuristicLab.Common; 26 27 using HeuristicLab.Core; 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 27 29 using HeuristicLab.Parameters; 28 30 using HeuristicLab.Persistence.Default.CompositeSerializers.Storable; … … 32 34 [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classificaton trees.")] 33 35 public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator { 34 private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";35 36 private const string ModelCreatorParameterName = "ModelCreator"; 37 private const string EvaluatorParameterName = "Evaluator"; 36 38 37 39 #region parameter properties … … 39 41 get { return (ILookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; } 40 42 } 43 44 public ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator> EvaluatorParameter { 45 get { 46 return (ILookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; 47 } 48 } 41 49 #endregion 42 50 43 protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner) 44 : base(original, cloner) { 45 } 46 47 public override IDeepCloneable Clone(Cloner cloner) { 48 return new SymbolicClassificationPruningOperator(this, cloner); 49 } 51 protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner) : base(original, cloner) { } 52 public override IDeepCloneable Clone(Cloner cloner) { return new SymbolicClassificationPruningOperator(this, cloner); } 50 53 51 54 [StorableConstructor] 52 55 protected SymbolicClassificationPruningOperator(bool deserializing) : base(deserializing) { } 53 56 54 public SymbolicClassificationPruningOperator( ) {55 Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, new SymbolicClassificationSolutionImpactValuesCalculator()));57 public SymbolicClassificationPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator) 58 : base(impactValuesCalculator) { 56 59 Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName)); 60 Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName)); 57 61 } 58 62 59 protected override ISymbolicDataAnalysisModel CreateModel() { 60 var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper); 61 var problemData = (IClassificationProblemData)ProblemData; 62 var rows = problemData.TrainingIndices; 63 model.RecalculateModelParameters(problemData, rows); 63 [StorableHook(HookType.AfterDeserialization)] 64 private void AfterDeserialization() { 65 // BackwardsCompatibility3.3 66 #region Backwards compatible code, remove with 3.4 67 base.ImpactValuesCalculator = new SymbolicClassificationSolutionImpactValuesCalculator(); 68 if (!Parameters.ContainsKey(EvaluatorParameterName)) { 69 Parameters.Add(new LookupParameter<ISymbolicClassificationSingleObjectiveEvaluator>(EvaluatorParameterName)); 70 } 71 #endregion 72 } 73 74 protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) { 75 var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(tree, interpreter, estimationLimits.Lower, estimationLimits.Upper); 76 var classificationProblemData = (IClassificationProblemData)problemData; 77 var rows = classificationProblemData.TrainingIndices; 78 model.RecalculateModelParameters(classificationProblemData, rows); 64 79 return model; 65 80 } 66 81 67 82 protected override double Evaluate(IDataAnalysisModel model) { 68 var classificationModel = (IClassificationModel)model; 69 var classificationProblemData = (IClassificationProblemData)ProblemData; 70 var trainingIndices = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size); 71 var estimatedValues = classificationModel.GetEstimatedClassValues(ProblemData.Dataset, trainingIndices); 72 var targetValues = ProblemData.Dataset.GetDoubleValues(classificationProblemData.TargetVariable, trainingIndices); 73 OnlineCalculatorError errorState; 74 var quality = OnlineAccuracyCalculator.Calculate(targetValues, estimatedValues, out errorState); 75 if (errorState != OnlineCalculatorError.None) return double.NaN; 76 return quality; 83 var evaluator = EvaluatorParameter.ActualValue; 84 var classificationModel = (ISymbolicClassificationModel)model; 85 var classificationProblemData = (IClassificationProblemData)ProblemDataParameter.ActualValue; 86 var rows = Enumerable.Range(FitnessCalculationPartitionParameter.ActualValue.Start, FitnessCalculationPartitionParameter.ActualValue.Size); 87 return evaluator.Evaluate(this.ExecutionContext, classificationModel.SymbolicExpressionTree, classificationProblemData, rows); 88 } 89 90 public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, ISymbolicClassificationModelCreator modelCreator, 91 SymbolicClassificationSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, 92 IClassificationProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows, 93 double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) { 94 var clonedTree = (ISymbolicExpressionTree)tree.Clone(); 95 var model = modelCreator.CreateSymbolicClassificationModel(clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper); 96 97 var nodes = clonedTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList(); 98 double qualityForImpactsCalculation = double.NaN; 99 100 for (int i = 0; i < nodes.Count; ++i) { 101 var node = nodes[i]; 102 if (node is ConstantTreeNode) continue; 103 104 double impactValue, replacementValue, newQualityForImpactsCalculation; 105 impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation); 106 107 if (pruneOnlyZeroImpactNodes && !impactValue.IsAlmost(0.0)) continue; 108 if (!pruneOnlyZeroImpactNodes && impactValue > nodeImpactThreshold) continue; 109 110 var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode(); 111 constantNode.Value = replacementValue; 112 113 ReplaceWithConstant(node, constantNode); 114 i += node.GetLength() - 1; // skip subtrees under the node that was folded 115 116 qualityForImpactsCalculation = newQualityForImpactsCalculation; 117 } 118 return model.SymbolicExpressionTree; 77 119 } 78 120 }
Note: See TracChangeset
for help on using the changeset viewer.