Changeset 17687 for branches/1837_Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionPruningOperator.cs
- Timestamp:
- 07/19/20 19:07:40 (4 years ago)
- File:
-
- 1 edited
Legend:
- Unmodified
- Added
- Removed
-
branches/1837_Sliding Window GP/HeuristicLab.Problems.DataAnalysis.Symbolic.Regression/3.4/SymbolicRegressionPruningOperator.cs
r10469 r17687 1 using System.Linq; 1 #region License Information 2 3 /* HeuristicLab 4 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL) 5 * 6 * This file is part of HeuristicLab. 7 * 8 * HeuristicLab is free software: you can redistribute it and/or modify 9 * it under the terms of the GNU General Public License as published by 10 * the Free Software Foundation, either version 3 of the License, or 11 * (at your option) any later version. 12 * 13 * HeuristicLab is distributed in the hope that it will be useful, 14 * but WITHOUT ANY WARRANTY; without even the implied warranty of 15 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 16 * GNU General Public License for more details. 17 * 18 * You should have received a copy of the GNU General Public License 19 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>. 20 */ 21 22 #endregion 23 24 using System.Collections.Generic; 25 using System.Linq; 2 26 using HeuristicLab.Common; 3 27 using HeuristicLab.Core; 28 using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding; 4 29 using HeuristicLab.Parameters; 5 using H euristicLab.Persistence.Default.CompositeSerializers.Storable;30 using HEAL.Attic; 6 31 7 32 namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Regression { 8 [Storable Class]33 [StorableType("75843B4E-C69C-423A-87BD-A64619D380BB")] 9 34 [Item("SymbolicRegressionPruningOperator", "An operator which prunes symbolic regression trees.")] 10 35 public class SymbolicRegressionPruningOperator : SymbolicDataAnalysisExpressionPruningOperator { 11 private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";12 private const string ImpactValuesCalculatorParameterDescription = "The impact values calculator to be used for figuring out the node impacts.";13 14 36 private const string EvaluatorParameterName = "Evaluator"; 15 37 38 #region parameter properties 16 39 public ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator> EvaluatorParameter { 17 40 get { return (ILookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>)Parameters[EvaluatorParameterName]; } 18 41 } 42 #endregion 19 43 20 44 protected SymbolicRegressionPruningOperator(SymbolicRegressionPruningOperator original, Cloner cloner) … … 26 50 27 51 [StorableConstructor] 28 protected SymbolicRegressionPruningOperator( bool deserializing) : base(deserializing) { }52 protected SymbolicRegressionPruningOperator(StorableConstructorFlag _) : base(_) { } 29 53 30 public SymbolicRegressionPruningOperator() { 31 var impactValuesCalculator = new SymbolicRegressionSolutionImpactValuesCalculator(); 32 Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, ImpactValuesCalculatorParameterDescription, impactValuesCalculator)); 54 public SymbolicRegressionPruningOperator(ISymbolicDataAnalysisSolutionImpactValuesCalculator impactValuesCalculator) 55 : base(impactValuesCalculator) { 33 56 Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName)); 34 57 } 35 58 36 protected override ISymbolicDataAnalysisModel CreateModel() { 37 return new SymbolicRegressionModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper); 59 [StorableHook(HookType.AfterDeserialization)] 60 private void AfterDeserialization() { 61 // BackwardsCompatibility3.3 62 #region Backwards compatible code, remove with 3.4 63 base.ImpactValuesCalculator = new SymbolicRegressionSolutionImpactValuesCalculator(); 64 if (!Parameters.ContainsKey(EvaluatorParameterName)) { 65 Parameters.Add(new LookupParameter<ISymbolicRegressionSingleObjectiveEvaluator>(EvaluatorParameterName)); 66 } 67 #endregion 68 } 69 70 protected override ISymbolicDataAnalysisModel CreateModel(ISymbolicExpressionTree tree, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IDataAnalysisProblemData problemData, DoubleLimit estimationLimits) { 71 var regressionProblemData = (IRegressionProblemData)problemData; 72 return new SymbolicRegressionModel(regressionProblemData.TargetVariable, tree, interpreter, estimationLimits.Lower, estimationLimits.Upper); 38 73 } 39 74 40 75 protected override double Evaluate(IDataAnalysisModel model) { 41 var regressionModel = (IRegressionModel)model; 42 var regressionProblemData = (IRegressionProblemData)ProblemData; 43 var trainingIndices = ProblemData.TrainingIndices.ToList(); 44 var estimatedValues = regressionModel.GetEstimatedValues(ProblemData.Dataset, trainingIndices); // also bounds the values 45 var targetValues = ProblemData.Dataset.GetDoubleValues(regressionProblemData.TargetVariable, trainingIndices); 46 OnlineCalculatorError errorState; 47 var quality = OnlinePearsonsRSquaredCalculator.Calculate(targetValues, estimatedValues, out errorState); 48 if (errorState != OnlineCalculatorError.None) return double.NaN; 49 return quality; 76 var regressionModel = (ISymbolicRegressionModel)model; 77 var regressionProblemData = (IRegressionProblemData)ProblemDataParameter.ActualValue; 78 var evaluator = EvaluatorParameter.ActualValue; 79 var fitnessEvaluationPartition = FitnessCalculationPartitionParameter.ActualValue; 80 var rows = Enumerable.Range(fitnessEvaluationPartition.Start, fitnessEvaluationPartition.Size); 81 return evaluator.Evaluate(this.ExecutionContext, regressionModel.SymbolicExpressionTree, regressionProblemData, rows); 82 } 83 84 public static ISymbolicExpressionTree Prune(ISymbolicExpressionTree tree, SymbolicRegressionSolutionImpactValuesCalculator impactValuesCalculator, ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, IRegressionProblemData problemData, DoubleLimit estimationLimits, IEnumerable<int> rows, double nodeImpactThreshold = 0.0, bool pruneOnlyZeroImpactNodes = false) { 85 var clonedTree = (ISymbolicExpressionTree)tree.Clone(); 86 var model = new SymbolicRegressionModel(problemData.TargetVariable, clonedTree, interpreter, estimationLimits.Lower, estimationLimits.Upper); 87 var nodes = clonedTree.Root.GetSubtree(0).GetSubtree(0).IterateNodesPrefix().ToList(); // skip the nodes corresponding to the ProgramRootSymbol and the StartSymbol 88 89 double qualityForImpactsCalculation = double.NaN; // pass a NaN value initially so the impact calculator will calculate the quality 90 91 for (int i = 0; i < nodes.Count; ++i) { 92 var node = nodes[i]; 93 if (node is ConstantTreeNode) continue; 94 95 double impactValue, replacementValue; 96 double newQualityForImpactsCalculation; 97 impactValuesCalculator.CalculateImpactAndReplacementValues(model, node, problemData, rows, out impactValue, out replacementValue, out newQualityForImpactsCalculation, qualityForImpactsCalculation); 98 99 if (pruneOnlyZeroImpactNodes && !impactValue.IsAlmost(0.0)) continue; 100 if (!pruneOnlyZeroImpactNodes && impactValue > nodeImpactThreshold) continue; 101 102 var constantNode = (ConstantTreeNode)node.Grammar.GetSymbol("Constant").CreateTreeNode(); 103 constantNode.Value = replacementValue; 104 105 ReplaceWithConstant(node, constantNode); 106 i += node.GetLength() - 1; // skip subtrees under the node that was folded 107 108 qualityForImpactsCalculation = newQualityForImpactsCalculation; 109 } 110 return model.SymbolicExpressionTree; 50 111 } 51 112 }
Note: See TracChangeset
for help on using the changeset viewer.