source: branches/DataPreprocessing/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SymbolicClassificationPruningOperator.cs @ 10538

Last change on this file since 10538 was 10538, checked in by pfleck, 6 years ago
  • merged trunk
File size: 3.4 KB
Line 
1using System.Linq;
2using HeuristicLab.Common;
3using HeuristicLab.Core;
4using HeuristicLab.Data;
5using HeuristicLab.Parameters;
6using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
7
8namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
9  [StorableClass]
10  [Item("SymbolicClassificationPruningOperator", "An operator which prunes symbolic classificaton trees.")]
11  public class SymbolicClassificationPruningOperator : SymbolicDataAnalysisExpressionPruningOperator {
12    private const string ImpactValuesCalculatorParameterName = "ImpactValuesCalculator";
13    private const string ModelCreatorParameterName = "ModelCreator";
14    private const string ApplyLinearScalingParmameterName = "ApplyLinearScaling";
15
16    #region parameter properties
17    public ILookupParameter<ISymbolicClassificationModelCreator> ModelCreatorParameter {
18      get { return (ILookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; }
19    }
20
21    public ILookupParameter<BoolValue> ApplyLinearScalingParameter {
22      get { return (ILookupParameter<BoolValue>)Parameters[ApplyLinearScalingParmameterName]; }
23    }
24    #endregion
25    #region properties
26    private ISymbolicClassificationModelCreator ModelCreator { get { return ModelCreatorParameter.ActualValue; } }
27    private BoolValue ApplyLinearScaling { get { return ApplyLinearScalingParameter.ActualValue; } }
28    #endregion
29
30    protected SymbolicClassificationPruningOperator(SymbolicClassificationPruningOperator original, Cloner cloner)
31      : base(original, cloner) {
32    }
33    public override IDeepCloneable Clone(Cloner cloner) {
34      return new SymbolicClassificationPruningOperator(this, cloner);
35    }
36
37    [StorableConstructor]
38    protected SymbolicClassificationPruningOperator(bool deserializing) : base(deserializing) { }
39
40    public SymbolicClassificationPruningOperator() {
41      Parameters.Add(new ValueParameter<ISymbolicDataAnalysisSolutionImpactValuesCalculator>(ImpactValuesCalculatorParameterName, new SymbolicClassificationSolutionImpactValuesCalculator()));
42      Parameters.Add(new LookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName));
43    }
44
45    protected override ISymbolicDataAnalysisModel CreateModel() {
46      var model = ModelCreator.CreateSymbolicClassificationModel(SymbolicExpressionTree, Interpreter, EstimationLimits.Lower, EstimationLimits.Upper);
47      var rows = Enumerable.Range(FitnessCalculationPartition.Start, FitnessCalculationPartition.Size);
48      var problemData = (IClassificationProblemData)ProblemData;
49      model.RecalculateModelParameters(problemData, rows);
50      return model;
51    }
52
53    protected override double Evaluate(IDataAnalysisModel model) {
54      var classificationModel = (IClassificationModel)model;
55      var classificationProblemData = (IClassificationProblemData)ProblemData;
56      var trainingIndices = ProblemData.TrainingIndices.ToList();
57      var estimatedValues = classificationModel.GetEstimatedClassValues(ProblemData.Dataset, trainingIndices);
58      var targetValues = ProblemData.Dataset.GetDoubleValues(classificationProblemData.TargetVariable, trainingIndices);
59      OnlineCalculatorError errorState;
60      var quality = OnlinePearsonsRSquaredCalculator.Calculate(targetValues, estimatedValues, out errorState);
61      if (errorState != OnlineCalculatorError.None) return double.NaN;
62      return quality;
63    }
64  }
65}
Note: See TracBrowser for help on using the repository browser.