Free cookie consent management tool by TermsFeed Policy Generator

source: branches/3026_IntegrationIntoSymSpace/HeuristicLab.Problems.DataAnalysis.Symbolic.Classification/3.4/MultiObjective/SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer.cs @ 17928

Last change on this file since 17928 was 17928, checked in by dpiringe, 3 years ago

#3026

  • merged trunk into branch
File size: 10.6 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HEAL.Attic;
25using HeuristicLab.Analysis;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32
namespace HeuristicLab.Problems.DataAnalysis.Symbolic.Classification {
  /// <summary>
  /// An operator that analyzes the training best symbolic classification solution for multi objective symbolic classification problems.
  /// </summary>
  /// <remarks>
  /// After the analysis of the base class, this operator maintains a "Pareto Front Analysis"
  /// scatter plot in the result collection. The plot shows training accuracy (and optionally
  /// test and validation accuracy) against tree length for those training-best solutions
  /// that are non-dominated with respect to (tree length, training accuracy).
  /// </remarks>
  [Item("SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer", "An operator that analyzes the training best symbolic classification solution for multi objective symbolic classification problems.")]
  [StorableType("EC30DC99-A5A8-43B0-81C1-BA9016A0A74C")]
  public sealed class SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer : SymbolicDataAnalysisMultiObjectiveTrainingBestSolutionAnalyzer<ISymbolicClassificationSolution>,
    ISymbolicDataAnalysisInterpreterOperator, ISymbolicDataAnalysisBoundedOperator, ISymbolicClassificationModelCreatorOperator {
    private const string ProblemDataParameterName = "ProblemData";
    private const string ModelCreatorParameterName = "ModelCreator";
    private const string SymbolicDataAnalysisTreeInterpreterParameterName = "SymbolicDataAnalysisTreeInterpreter";
    private const string EstimationLimitsParameterName = "EstimationLimits";
    private const string MaximumSymbolicExpressionTreeLengthParameterName = "MaximumSymbolicExpressionTreeLength";
    private const string ValidationPartitionParameterName = "ValidationPartition";
    private const string AnalyzeTestErrorParameterName = "Analyze Test Error";

    #region parameter properties
    /// <summary>The problem data used to build, scale and evaluate solutions.</summary>
    public ILookupParameter<IClassificationProblemData> ProblemDataParameter {
      get { return (ILookupParameter<IClassificationProblemData>)Parameters[ProblemDataParameterName]; }
    }
    /// <summary>The creator used to turn a symbolic expression tree into a classification model.</summary>
    public IValueLookupParameter<ISymbolicClassificationModelCreator> ModelCreatorParameter {
      get { return (IValueLookupParameter<ISymbolicClassificationModelCreator>)Parameters[ModelCreatorParameterName]; }
    }
    ILookupParameter<ISymbolicClassificationModelCreator> ISymbolicClassificationModelCreatorOperator.ModelCreatorParameter {
      get { return ModelCreatorParameter; }
    }
    /// <summary>The interpreter used to evaluate the symbolic expression trees.</summary>
    public ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter> SymbolicDataAnalysisTreeInterpreterParameter {
      get { return (ILookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>)Parameters[SymbolicDataAnalysisTreeInterpreterParameterName]; }
    }
    /// <summary>The lower and upper limit passed to the model creator for estimated values.</summary>
    public IValueLookupParameter<DoubleLimit> EstimationLimitsParameter {
      get { return (IValueLookupParameter<DoubleLimit>)Parameters[EstimationLimitsParameterName]; }
    }
    /// <summary>The maximal tree length; used as the fixed upper bound of the plot's x-axis.</summary>
    public ILookupParameter<IntValue> MaximumSymbolicExpressionTreeLengthParameter {
      get { return (ILookupParameter<IntValue>)Parameters[MaximumSymbolicExpressionTreeLengthParameterName]; }
    }
    /// <summary>The validation partition; a non-empty partition adds a validation accuracy row to the plot.</summary>
    public IValueLookupParameter<IntRange> ValidationPartitionParameter {
      get { return (IValueLookupParameter<IntRange>)Parameters[ValidationPartitionParameterName]; }
    }
    /// <summary>Flag parameter controlling whether a test accuracy row is added to the plot.</summary>
    public IFixedValueParameter<BoolValue> AnalyzeTestErrorParameter {
      get { return (IFixedValueParameter<BoolValue>)Parameters[AnalyzeTestErrorParameterName]; }
    }
    /// <summary>Gets or sets whether the test accuracy is displayed in the Pareto front plot.</summary>
    public bool AnalyzeTestError {
      get { return AnalyzeTestErrorParameter.Value.Value; }
      set { AnalyzeTestErrorParameter.Value.Value = value; }
    }
    #endregion

    [StorableConstructor]
    private SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer(StorableConstructorFlag _) : base(_) { }
    private SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer(SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer original, Cloner cloner) : base(original, cloner) { }
    public SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<IClassificationProblemData>(ProblemDataParameterName, "The problem data for the symbolic classification solution."));
      Parameters.Add(new ValueLookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName, ""));
      Parameters.Add(new LookupParameter<ISymbolicDataAnalysisExpressionTreeInterpreter>(SymbolicDataAnalysisTreeInterpreterParameterName, "The symbolic data analysis tree interpreter for the symbolic expression tree."));
      Parameters.Add(new ValueLookupParameter<DoubleLimit>(EstimationLimitsParameterName, "The lower and upper limit for the estimated values produced by the symbolic classification model."));
      Parameters.Add(new LookupParameter<IntValue>(MaximumSymbolicExpressionTreeLengthParameterName, "Maximal length of the symbolic expression.") { Hidden = true });
      Parameters.Add(new ValueLookupParameter<IntRange>(ValidationPartitionParameterName, "The validation partition."));
      Parameters.Add(new FixedValueParameter<BoolValue>(AnalyzeTestErrorParameterName, "Flag whether the test error should be displayed in the Pareto-Front", new BoolValue(false)));
    }
    public override IDeepCloneable Clone(Cloner cloner) {
      return new SymbolicClassificationMultiObjectiveTrainingBestSolutionAnalyzer(this, cloner);
    }

    [StorableHook(HookType.AfterDeserialization)]
    private void AfterDeserialization() {
      // BackwardsCompatibility3.4
      #region Backwards compatible code, remove with 3.5
      // Older persisted instances lack the ModelCreator parameter; add it so lookups succeed.
      if (!Parameters.ContainsKey(ModelCreatorParameterName))
        Parameters.Add(new ValueLookupParameter<ISymbolicClassificationModelCreator>(ModelCreatorParameterName, ""));
      #endregion
    }

    /// <summary>
    /// Builds a classification solution from <paramref name="bestTree"/>.
    /// </summary>
    /// <param name="bestTree">The tree to wrap; it is cloned, so the caller's instance is not handed to the model.</param>
    /// <param name="bestQuality">The qualities of the tree (not used here).</param>
    /// <returns>
    /// A solution whose model is optionally linearly scaled and has its model parameters
    /// recalculated on the training indices. The solution references a clone of the problem data.
    /// </returns>
    protected override ISymbolicClassificationSolution CreateSolution(ISymbolicExpressionTree bestTree, double[] bestQuality) {
      var problemData = ProblemDataParameter.ActualValue;
      var estimationLimits = EstimationLimitsParameter.ActualValue;

      var model = ModelCreatorParameter.ActualValue.CreateSymbolicClassificationModel(problemData.TargetVariable, (ISymbolicExpressionTree)bestTree.Clone(), SymbolicDataAnalysisTreeInterpreterParameter.ActualValue, estimationLimits.Lower, estimationLimits.Upper);
      if (ApplyLinearScalingParameter.ActualValue.Value) model.Scale(problemData);

      model.RecalculateModelParameters(problemData, problemData.TrainingIndices);
      return model.CreateClassificationSolution((IClassificationProblemData)problemData.Clone());
    }

    /// <summary>
    /// Runs the base analysis, then (re)creates the "Pareto Front Analysis" scatter plot.
    /// The plot contains training accuracy, test accuracy (if <see cref="AnalyzeTestError"/>)
    /// and validation accuracy (if the validation partition is non-empty) against tree length.
    /// </summary>
    /// <returns>The operation returned by the base class' <c>Apply</c>.</returns>
    public override IOperation Apply() {
      var operation = base.Apply();
      var paretoFront = TrainingBestSolutionsParameter.ActualValue;

      IResult result;
      ScatterPlot qualityToTreeSize;
      if (!ResultCollection.TryGetValue("Pareto Front Analysis", out result)) {
        // First invocation: create the plot with fixed axes (x: 0..max tree length, y: 0..1 accuracy).
        qualityToTreeSize = new ScatterPlot("Quality vs Tree Size", "");
        qualityToTreeSize.VisualProperties.XAxisMinimumAuto = false;
        qualityToTreeSize.VisualProperties.XAxisMaximumAuto = false;
        qualityToTreeSize.VisualProperties.YAxisMinimumAuto = false;
        qualityToTreeSize.VisualProperties.YAxisMaximumAuto = false;

        qualityToTreeSize.VisualProperties.XAxisMinimumFixedValue = 0;
        qualityToTreeSize.VisualProperties.XAxisMaximumFixedValue = MaximumSymbolicExpressionTreeLengthParameter.ActualValue.Value;
        qualityToTreeSize.VisualProperties.YAxisMinimumFixedValue = 0;
        qualityToTreeSize.VisualProperties.YAxisMaximumFixedValue = 1;
        ResultCollection.Add(new Result("Pareto Front Analysis", qualityToTreeSize));
      } else {
        qualityToTreeSize = (ScatterPlot)result.Value;
      }

      var sizeParetoFront = CalculateSizeParetoFront(paretoFront);

      qualityToTreeSize.Rows.Clear();
      var trainingRow = new ScatterPlotDataRow("Training Accuracy", "", sizeParetoFront.Select(x => new Point2D<double>(x.Model.SymbolicExpressionTree.Length, x.TrainingAccuracy, x)));
      trainingRow.VisualProperties.PointSize = 8;
      qualityToTreeSize.Rows.Add(trainingRow);

      if (AnalyzeTestError) {
        var testRow = new ScatterPlotDataRow("Test Accuracy", "",
          sizeParetoFront.Select(x => new Point2D<double>(x.Model.SymbolicExpressionTree.Length, x.TestAccuracy, x)));
        testRow.VisualProperties.PointSize = 8;
        qualityToTreeSize.Rows.Add(testRow);
      }

      var validationPartition = ValidationPartitionParameter.ActualValue;
      // NOTE(review): ActualValue is presumably null when no validation partition can be resolved;
      // skip the validation row in that case instead of throwing.
      if (validationPartition != null && validationPartition.Size != 0) {
        var problemData = ProblemDataParameter.ActualValue;
        var validationIndizes = Enumerable.Range(validationPartition.Start, validationPartition.Size).ToList();
        var targetValues = problemData.Dataset.GetDoubleValues(problemData.TargetVariable, validationIndizes).ToList();
        OnlineCalculatorError error;
        // Each point carries its solution as tag, consistent with the training and test rows.
        var validationRow = new ScatterPlotDataRow("Validation Accuracy", "",
          sizeParetoFront.Select(x => new Point2D<double>(x.Model.SymbolicExpressionTree.Length,
          OnlineAccuracyCalculator.Calculate(targetValues, x.GetEstimatedClassValues(validationIndizes), out error), x)));
        validationRow.VisualProperties.PointSize = 7;
        qualityToTreeSize.Rows.Add(validationRow);
      }

      return operation;
    }

    /// <summary>
    /// Reduces <paramref name="solutions"/> to those that are non-dominated with respect to
    /// (tree length, training accuracy).
    /// </summary>
    /// <param name="solutions">The candidate solutions (e.g. the multi-objective training best solutions).</param>
    /// <returns>
    /// The solutions ordered by increasing tree length. Each retained solution has a strictly
    /// higher training accuracy than its predecessor, and at most one solution is kept per
    /// tree length.
    /// </returns>
    private static List<ISymbolicClassificationSolution> CalculateSizeParetoFront(IEnumerable<ISymbolicClassificationSolution> solutions) {
      var front = new List<ISymbolicClassificationSolution>();
      foreach (var solution in solutions.OrderBy(s => s.Model.SymbolicExpressionTree.Length)) {
        if (front.Count == 0) {
          front.Add(solution);
          continue;
        }

        var last = front[front.Count - 1];
        // Dominated by (or equal to) a solution that is at most as long: skip.
        if (solution.TrainingAccuracy <= last.TrainingAccuracy) continue;

        // Replace the last retained entry only if it has the same tree length. Comparing against
        // the previously *visited* solution instead could drop a shorter, non-dominated solution.
        if (last.Model.SymbolicExpressionTree.Length == solution.Model.SymbolicExpressionTree.Length)
          front.RemoveAt(front.Count - 1);
        front.Add(solution);
      }
      return front;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.