Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SingleObjective/SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator.cs @ 17578

Last change on this file since 17578 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 8.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using HEAL.Attic;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
31using HeuristicLab.PluginInfrastructure;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// Single-objective classification evaluator that scores a symbolic expression tree by a
  /// weighted mean squared error. Three weights (definite, positive-class, negative-classes)
  /// are passed to <c>OnlineWeightedClassificationMeanSquaredErrorCalculator</c>, which applies
  /// them to the squared residuals as described in the <see cref="ItemAttribute"/> text below.
  /// Lower values are better (<see cref="Maximization"/> is false).
  /// </summary>
  [NonDiscoverableType]
  [Item("Weighted Residuals Mean Squared Error Evaluator", @"A modified mean squared error evaluator that enables the possibility to weight residuals differently.
The first residual category belongs to estimated values which definitely belong to a specific class because the estimated value is located above the maximum or below the minimum of all the class values (DefiniteResidualsWeight).
The second residual category represents residuals which belong to the positive class whereby the estimated value is located between the positive and a negative class (PositiveClassResidualsWeight).
All other cases are represented by the third category (NegativeClassesResidualsWeight).
The weight gets multiplied to the squared error. Note that the Evaluator acts like a normal MSE-Evaluator if all the weights are set to 1.")]
  [StorableType("A3193296-1A0F-46E2-8F43-22E2ED9CFFC5")]
  public sealed class SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator : SymbolicClassificationSingleObjectiveEvaluator {
    private const string DefiniteResidualsWeightParameterName = "DefiniteResidualsWeight";
    private const string PositiveClassResidualsWeightParameterName = "PositiveClassResidualsWeight";
    private const string NegativeClassesResidualsWeightParameterName = "NegativeClassesResidualsWeight";

    /// <summary>Constructor used by the HEAL.Attic persistence layer.</summary>
    [StorableConstructor]
    private SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(StorableConstructorFlag _) : base(_) { }

    /// <summary>Cloning constructor; all state lives in the (base-cloned) parameter collection.</summary>
    private SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator original, Cloner cloner)
      : base(original, cloner) {
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(this, cloner);
    }

    /// <summary>
    /// Creates the evaluator and registers the three residual-weight parameters, each initialised to 1.
    /// </summary>
    public SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator()
      : base() {
      Parameters.Add(new FixedValueParameter<DoubleValue>(DefiniteResidualsWeightParameterName, "Weight of residuals which definitely belong to a specific class because the estimated values is located above the maximum or below the minimum of all the class values.", new DoubleValue(1)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(PositiveClassResidualsWeightParameterName, "Weight of residuals which belong to the positive class whereby the estimated value is located between the positive and a negative class.", new DoubleValue(1)));
      Parameters.Add(new FixedValueParameter<DoubleValue>(NegativeClassesResidualsWeightParameterName, "Weight of residuals which are not covered by the DefiniteResidualsWeight or the PositiveClassResidualsWeight.", new DoubleValue(1)));
    }

    #region parameter properties
    public IFixedValueParameter<DoubleValue> DefiniteResidualsWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[DefiniteResidualsWeightParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> PositiveClassResidualsWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[PositiveClassResidualsWeightParameterName]; }
    }
    public IFixedValueParameter<DoubleValue> NegativeClassesResidualsWeightParameter {
      get { return (IFixedValueParameter<DoubleValue>)Parameters[NegativeClassesResidualsWeightParameterName]; }
    }
    #endregion

    #region properties
    /// <summary>The weighted MSE is an error measure, so smaller is better.</summary>
    public override bool Maximization { get { return false; } }

    /// <summary>Current value of the DefiniteResidualsWeight parameter.</summary>
    public double DefiniteResidualsWeight {
      get { return DefiniteResidualsWeightParameter.Value.Value; }
    }
    /// <summary>Current value of the PositiveClassResidualsWeight parameter.</summary>
    public double PositiveClassResidualsWeight {
      get { return PositiveClassResidualsWeightParameter.Value.Value; }
    }
    /// <summary>Current value of the NegativeClassesResidualsWeight parameter.</summary>
    public double NegativeClassesResidualsWeight {
      get { return NegativeClassesResidualsWeightParameter.Value.Value; }
    }
    #endregion

    /// <summary>
    /// Operator entry point: evaluates the tree found in the current scope on the configured rows
    /// and writes the resulting weighted MSE to the quality parameter.
    /// </summary>
    public override IOperation InstrumentedApply() {
      IEnumerable<int> rows = GenerateRowsToEvaluate();
      var solution = SymbolicExpressionTreeParameter.ActualValue;
      double quality = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, ProblemDataParameter.ActualValue, rows, ApplyLinearScalingParameter.ActualValue.Value,
        DefiniteResidualsWeight, PositiveClassResidualsWeight, NegativeClassesResidualsWeight);
      QualityParameter.ActualValue = new DoubleValue(quality);
      return base.InstrumentedApply();
    }

    /// <summary>
    /// Computes the weighted-residuals mean squared error of <paramref name="tree"/> on <paramref name="rows"/>.
    /// </summary>
    /// <param name="interpreter">Interpreter used to evaluate the tree on the dataset.</param>
    /// <param name="tree">The symbolic expression tree to evaluate.</param>
    /// <param name="lowerEstimationLimit">Lower bound applied to the estimated values.</param>
    /// <param name="upperEstimationLimit">Upper bound applied to the estimated values.</param>
    /// <param name="problemData">Supplies the dataset, target variable, class values and positive class.</param>
    /// <param name="rows">Row indices to evaluate.</param>
    /// <param name="applyLinearScaling">If true, estimates are linearly scaled (via <c>CalculateWithScaling</c>) before scoring.</param>
    /// <param name="definiteResidualsWeight">Weight for the "definite" residual category.</param>
    /// <param name="positiveClassResidualsWeight">Weight for the positive-class residual category.</param>
    /// <param name="negativeClassesResidualsWeight">Weight for all remaining residuals.</param>
    /// <returns>The weighted MSE, or <see cref="double.NaN"/> if the underlying calculator reports an error state.</returns>
    public static double Calculate(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, double lowerEstimationLimit, double upperEstimationLimit, IClassificationProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
      double definiteResidualsWeight, double positiveClassResidualsWeight, double negativeClassesResidualsWeight) {
      IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows);
      IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
      OnlineCalculatorError errorState;

      double positiveClassValue = problemData.GetClassValue(problemData.PositiveClass);

      // Smallest and largest class value; both stay 0 when there are no class values
      // (ElementAtOrDefault on an empty sequence), matching the previous behaviour.
      double classValuesMin = problemData.ClassValues.ElementAtOrDefault(0);
      double classValuesMax = classValuesMin;
      foreach (double classValue in problemData.ClassValues) {
        if (classValuesMin > classValue) classValuesMin = classValue;
        if (classValuesMax < classValue) classValuesMax = classValue;
      }

      double quality;
      if (applyLinearScaling) {
        var calculator = new OnlineWeightedClassificationMeanSquaredErrorCalculator(positiveClassValue, classValuesMax, classValuesMin,
          definiteResidualsWeight, positiveClassResidualsWeight, negativeClassesResidualsWeight);
        // NOTE(review): the dataset row count is passed as the final argument; presumably an upper
        // bound on the number of values processed by CalculateWithScaling — confirm in the base class.
        CalculateWithScaling(targetValues, estimatedValues, lowerEstimationLimit, upperEstimationLimit, calculator, problemData.Dataset.Rows);
        errorState = calculator.ErrorState;
        quality = calculator.WeightedResidualsMeanSquaredError;
      } else {
        IEnumerable<double> boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit);
        quality = OnlineWeightedClassificationMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues, positiveClassValue, classValuesMax,
          classValuesMin, definiteResidualsWeight, positiveClassResidualsWeight, negativeClassesResidualsWeight, out errorState);
      }
      if (errorState != OnlineCalculatorError.None) return Double.NaN;
      return quality;
    }

    /// <summary>
    /// Evaluates <paramref name="tree"/> outside of the normal operator pipeline, resolving the
    /// interpreter, estimation limits and linear-scaling flag through <paramref name="context"/>.
    /// </summary>
    /// <returns>The weighted MSE as computed by <see cref="Calculate"/>.</returns>
    public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IClassificationProblemData problemData, IEnumerable<int> rows) {
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
      EstimationLimitsParameter.ExecutionContext = context;
      ApplyLinearScalingParameter.ExecutionContext = context;
      try {
        return Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, EstimationLimitsParameter.ActualValue.Lower, EstimationLimitsParameter.ActualValue.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value, DefiniteResidualsWeight, PositiveClassResidualsWeight, NegativeClassesResidualsWeight);
      } finally {
        // Always unbind the borrowed context, even if evaluation throws, so the parameters
        // are not left pointing at the caller's (possibly stale) execution context.
        SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
        EstimationLimitsParameter.ExecutionContext = null;
        ApplyLinearScalingParameter.ExecutionContext = null;
      }
    }
  }
}
Note: See TracBrowser for help on using the repository browser.