Free cookie consent management tool by TermsFeed Policy Generator

source: stable/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/SingleObjective/SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator.cs @ 17912

Last change on this file since 17912 was 17181, checked in by swagner, 5 years ago

#2875: Merged r17180 from trunk to stable

File size: 8.9 KB
RevLine 
[12416]1#region License Information
2/* HeuristicLab
[17181]3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
[12416]4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
[16788]25using HEAL.Attic;
[12416]26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Parameters;
[16788]31using HeuristicLab.PluginInfrastructure;
[12416]32
33namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
/// <summary>
/// Single-objective classification evaluator computing a mean squared error in which each squared
/// residual is multiplied by one of three configurable weights (definite, positive-class, negative-classes),
/// according to the residual categories described in the <see cref="ItemAttribute"/> text below.
/// The per-residual categorization and weighting is delegated to
/// <c>OnlineWeightedClassificationMeanSquaredErrorCalculator</c> (not part of this class).
/// </summary>
[NonDiscoverableType]
[Item("Weighted Residuals Mean Squared Error Evaluator", @"A modified mean squared error evaluator that enables the possibility to weight residuals differently.
The first residual category belongs to estimated values which definitely belong to a specific class because the estimated value is located above the maximum or below the minimum of all the class values (DefiniteResidualsWeight).
The second residual category represents residuals which belong to the positive class whereby the estimated value is located between the positive and a negative class (PositiveClassResidualsWeight).
All other cases are represented by the third category (NegativeClassesResidualsWeight).
The weight gets multiplied to the squared error. Note that the Evaluator acts like a normal MSE-Evaluator if all the weights are set to 1.")]
[StorableType("A3193296-1A0F-46E2-8F43-22E2ED9CFFC5")]
public sealed class SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator : SymbolicClassificationSingleObjectiveEvaluator {
  private const string DefiniteResidualsWeightParameterName = "DefiniteResidualsWeight";
  private const string PositiveClassResidualsWeightParameterName = "PositiveClassResidualsWeight";
  private const string NegativeClassesResidualsWeightParameterName = "NegativeClassesResidualsWeight";

  [StorableConstructor]
  private SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(StorableConstructorFlag _) : base(_) { }
  private SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator original, Cloner cloner)
    : base(original, cloner) {
  }
  public override IDeepCloneable Clone(Cloner cloner) {
    return new SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator(this, cloner);
  }

  /// <summary>
  /// Creates the evaluator and registers the three residual-weight parameters, each defaulting to 1
  /// (i.e. plain MSE behaviour, per the item description).
  /// </summary>
  public SymbolicClassificationSingleObjectiveWeightedResidualsMeanSquaredErrorEvaluator()
    : base() {
    Parameters.Add(new FixedValueParameter<DoubleValue>(DefiniteResidualsWeightParameterName, "Weight of residuals which definitely belong to a specific class because the estimated values is located above the maximum or below the minimum of all the class values.", new DoubleValue(1)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(PositiveClassResidualsWeightParameterName, "Weight of residuals which belong to the positive class whereby the estimated value is located between the positive and a negative class.", new DoubleValue(1)));
    Parameters.Add(new FixedValueParameter<DoubleValue>(NegativeClassesResidualsWeightParameterName, "Weight of residuals which are not covered by the DefiniteResidualsWeight or the PositiveClassResidualsWeight.", new DoubleValue(1)));
  }

  #region parameter properties
  public IFixedValueParameter<DoubleValue> DefiniteResidualsWeightParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[DefiniteResidualsWeightParameterName]; }
  }
  public IFixedValueParameter<DoubleValue> PositiveClassResidualsWeightParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[PositiveClassResidualsWeightParameterName]; }
  }
  public IFixedValueParameter<DoubleValue> NegativeClassesResidualsWeightParameter {
    get { return (IFixedValueParameter<DoubleValue>)Parameters[NegativeClassesResidualsWeightParameterName]; }
  }
  #endregion

  #region properties
  /// <summary>An error measure: smaller values are better.</summary>
  public override bool Maximization { get { return false; } }

  /// <summary>Current value of the "DefiniteResidualsWeight" parameter.</summary>
  public double DefiniteResidualsWeight {
    get { return DefiniteResidualsWeightParameter.Value.Value; }
  }
  /// <summary>Current value of the "PositiveClassResidualsWeight" parameter.</summary>
  public double PositiveClassResidualsWeight {
    get { return PositiveClassResidualsWeightParameter.Value.Value; }
  }
  /// <summary>Current value of the "NegativeClassesResidualsWeight" parameter.</summary>
  public double NegativeClassesResidualsWeight {
    get { return NegativeClassesResidualsWeightParameter.Value.Value; }
  }
  #endregion

  /// <summary>
  /// Operator entry point: evaluates the tree from <c>SymbolicExpressionTreeParameter</c> on the rows
  /// produced by <c>GenerateRowsToEvaluate()</c>, stores the result in <c>QualityParameter</c>, then
  /// continues with the base implementation.
  /// </summary>
  public override IOperation InstrumentedApply() {
    IEnumerable<int> rows = GenerateRowsToEvaluate();
    var solution = SymbolicExpressionTreeParameter.ActualValue;
    var estimationLimits = EstimationLimitsParameter.ActualValue;
    double quality = Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, solution, estimationLimits.Lower, estimationLimits.Upper, ProblemDataParameter.ActualValue, rows, ApplyLinearScalingParameter.ActualValue.Value,
      DefiniteResidualsWeight, PositiveClassResidualsWeight, NegativeClassesResidualsWeight);
    QualityParameter.ActualValue = new DoubleValue(quality);
    return base.InstrumentedApply();
  }

  /// <summary>
  /// Computes the weighted-residuals mean squared error of <paramref name="tree"/> on <paramref name="rows"/>.
  /// </summary>
  /// <param name="interpreter">Interpreter used to obtain the tree's estimated values.</param>
  /// <param name="tree">The symbolic expression tree to evaluate.</param>
  /// <param name="lowerEstimationLimit">Lower bound applied to estimated values.</param>
  /// <param name="upperEstimationLimit">Upper bound applied to estimated values.</param>
  /// <param name="problemData">Supplies the dataset, target variable, class values and positive class.</param>
  /// <param name="rows">Row indices to evaluate (enumerated once for estimates and once for targets).</param>
  /// <param name="applyLinearScaling">
  /// If true, estimates are passed through the base class's <c>CalculateWithScaling</c>
  /// (presumably linear scaling before limiting — see base class); otherwise they are only limited to the estimation range.
  /// </param>
  /// <param name="definiteResidualsWeight">Weight for the "definite" residual category.</param>
  /// <param name="positiveClassResidualsWeight">Weight for the positive-class residual category.</param>
  /// <param name="negativeClassesResidualsWeight">Weight for all remaining residuals.</param>
  /// <returns>The weighted MSE, or <see cref="Double.NaN"/> if the calculator reports an error state.</returns>
  public static double Calculate(ISymbolicDataAnalysisExpressionTreeInterpreter interpreter, ISymbolicExpressionTree tree, double lowerEstimationLimit, double upperEstimationLimit, IClassificationProblemData problemData, IEnumerable<int> rows, bool applyLinearScaling,
    double definiteResidualsWeight, double positiveClassResidualsWeight, double negativeClassesResidualsWeight) {
    IEnumerable<double> estimatedValues = interpreter.GetSymbolicExpressionTreeValues(tree, problemData.Dataset, rows);
    IEnumerable<double> targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, rows);
    OnlineCalculatorError errorState;

    double positiveClassValue = problemData.GetClassValue(problemData.PositiveClass);

    // Range of all class values; the calculator needs it to recognize "definite" residuals.
    // If ClassValues is empty, both bounds fall back to default(double) == 0.
    double classValuesMin = problemData.ClassValues.ElementAtOrDefault(0);
    double classValuesMax = classValuesMin;
    foreach (double classValue in problemData.ClassValues) {
      if (classValuesMin > classValue) classValuesMin = classValue;
      if (classValuesMax < classValue) classValuesMax = classValue;
    }

    double quality;
    if (applyLinearScaling) {
      var calculator = new OnlineWeightedClassificationMeanSquaredErrorCalculator(positiveClassValue, classValuesMax, classValuesMin,
        definiteResidualsWeight, positiveClassResidualsWeight, negativeClassesResidualsWeight);
      CalculateWithScaling(targetValues, estimatedValues, lowerEstimationLimit, upperEstimationLimit, calculator, problemData.Dataset.Rows);
      errorState = calculator.ErrorState;
      quality = calculator.WeightedResidualsMeanSquaredError;
    } else {
      IEnumerable<double> boundedEstimatedValues = estimatedValues.LimitToRange(lowerEstimationLimit, upperEstimationLimit);
      quality = OnlineWeightedClassificationMeanSquaredErrorCalculator.Calculate(targetValues, boundedEstimatedValues, positiveClassValue, classValuesMax,
        classValuesMin, definiteResidualsWeight, positiveClassResidualsWeight, negativeClassesResidualsWeight, out errorState);
    }
    // Any calculator error (e.g. invalid/insufficient values) yields NaN rather than a misleading number.
    if (errorState != OnlineCalculatorError.None) return Double.NaN;
    return quality;
  }

  /// <summary>
  /// Evaluates <paramref name="tree"/> outside of the regular operator pipeline. The interpreter,
  /// estimation-limit and linear-scaling parameters are temporarily bound to <paramref name="context"/>
  /// and always unbound afterwards, even if the evaluation throws.
  /// </summary>
  /// <returns>The weighted MSE as returned by <see cref="Calculate"/>.</returns>
  public override double Evaluate(IExecutionContext context, ISymbolicExpressionTree tree, IClassificationProblemData problemData, IEnumerable<int> rows) {
    SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = context;
    EstimationLimitsParameter.ExecutionContext = context;
    ApplyLinearScalingParameter.ExecutionContext = context;
    try {
      var estimationLimits = EstimationLimitsParameter.ActualValue;
      return Calculate(SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, tree, estimationLimits.Lower, estimationLimits.Upper, problemData, rows, ApplyLinearScalingParameter.ActualValue.Value, DefiniteResidualsWeight, PositiveClassResidualsWeight, NegativeClassesResidualsWeight);
    } finally {
      // Reset in finally so an exception from the interpreter/calculator cannot leave a stale context bound.
      SymbolicDataAnalysisTreeInterpreterParameter.ExecutionContext = null;
      EstimationLimitsParameter.ExecutionContext = null;
      ApplyLinearScalingParameter.ExecutionContext = null;
    }
  }
}
141}
Note: See TracBrowser for help on using the repository browser.