Free cookie consent management tool by TermsFeed Policy Generator

source: branches/ParameterBinding/HeuristicLab.Problems.DataAnalysis.Classification/3.3/Symbolic/Analyzer/ValidationBestSymbolicClassificationSolutionAnalyzer.cs @ 10204

Last change on this file since 10204 was 4722, checked in by swagner, 14 years ago

Merged cloning refactoring branch back into trunk (#922)

File size: 17.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System.Collections.Generic;
23using System.Linq;
24using HeuristicLab.Analysis;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Data;
28using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
29using HeuristicLab.Operators;
30using HeuristicLab.Optimization;
31using HeuristicLab.Parameters;
32using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
33using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic;
34using HeuristicLab.Problems.DataAnalysis.Regression.Symbolic.Analyzers;
35using HeuristicLab.Problems.DataAnalysis.Symbolic;
36
namespace HeuristicLab.Problems.DataAnalysis.Classification {
  /// <summary>
  /// Analyzer that evaluates every symbolic expression tree of the current scope tree on a randomly
  /// sampled subset of the validation partition and keeps the best tree seen over the whole run as a
  /// linearly scaled <see cref="SymbolicClassificationSolution"/>, together with its training and test accuracy.
  /// </summary>
  [Item("ValidationBestSymbolicClassificationSolutionAnalyzer", "An operator that analyzes the validation best symbolic classification solution.")]
  [StorableClass]
  public class ValidationBestSymbolicClassificationSolutionAnalyzer : SingleSuccessorOperator, ISymbolicClassificationAnalyzer {
    private const string MaximizationParameterName = "Maximization";
    private const string GenerationsParameterName = "Generations";
    private const string RandomParameterName = "Random";
    private const string SymbolicExpressionTreeParameterName = "SymbolicExpressionTree";
    private const string SymbolicExpressionTreeInterpreterParameterName = "SymbolicExpressionTreeInterpreter";

    private const string ClassificationProblemDataParameterName = "ClassificationProblemData";
    private const string EvaluatorParameterName = "Evaluator";
    private const string ValidationSamplesStartParameterName = "SamplesStart";
    private const string ValidationSamplesEndParameterName = "SamplesEnd";
    private const string RelativeNumberOfEvaluatedSamplesParameterName = "RelativeNumberOfEvaluatedSamples";
    private const string UpperEstimationLimitParameterName = "UpperEstimationLimit";
    private const string LowerEstimationLimitParameterName = "LowerEstimationLimit";

    private const string ResultsParameterName = "Results";
    private const string BestValidationQualityParameterName = "Best validation quality";
    private const string BestValidationSolutionParameterName = "Best validation solution";
    private const string BestSolutionAccuracyTrainingParameterName = "Best solution accuracy (training)";
    private const string BestSolutionAccuracyTestParameterName = "Best solution accuracy (test)";
    private const string VariableFrequenciesParameterName = "VariableFrequencies";

    #region parameter properties
    public ILookupParameter<BoolValue> MaximizationParameter {
      get { return (ILookupParameter<BoolValue>)Parameters[MaximizationParameterName]; }
    }
    public ILookupParameter<IntValue> GenerationsParameter {
      get { return (ILookupParameter<IntValue>)Parameters[GenerationsParameterName]; }
    }
    public ILookupParameter<IRandom> RandomParameter {
      get { return (ILookupParameter<IRandom>)Parameters[RandomParameterName]; }
    }
    public ScopeTreeLookupParameter<SymbolicExpressionTree> SymbolicExpressionTreeParameter {
      get { return (ScopeTreeLookupParameter<SymbolicExpressionTree>)Parameters[SymbolicExpressionTreeParameterName]; }
    }
    public IValueLookupParameter<ISymbolicExpressionTreeInterpreter> SymbolicExpressionTreeInterpreterParameter {
      get { return (IValueLookupParameter<ISymbolicExpressionTreeInterpreter>)Parameters[SymbolicExpressionTreeInterpreterParameterName]; }
    }

    public ILookupParameter<ClassificationProblemData> ClassificationProblemDataParameter {
      get { return (ILookupParameter<ClassificationProblemData>)Parameters[ClassificationProblemDataParameterName]; }
    }
    public ILookupParameter<ISymbolicClassificationEvaluator> EvaluatorParameter {
      get { return (ILookupParameter<ISymbolicClassificationEvaluator>)Parameters[EvaluatorParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesStartParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesStartParameterName]; }
    }
    public IValueLookupParameter<IntValue> ValidationSamplesEndParameter {
      get { return (IValueLookupParameter<IntValue>)Parameters[ValidationSamplesEndParameterName]; }
    }
    public IValueParameter<PercentValue> RelativeNumberOfEvaluatedSamplesParameter {
      get { return (IValueParameter<PercentValue>)Parameters[RelativeNumberOfEvaluatedSamplesParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> UpperEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[UpperEstimationLimitParameterName]; }
    }
    public IValueLookupParameter<DoubleValue> LowerEstimationLimitParameter {
      get { return (IValueLookupParameter<DoubleValue>)Parameters[LowerEstimationLimitParameterName]; }
    }
    public ILookupParameter<DataTable> VariableFrequenciesParameter {
      get { return (ILookupParameter<DataTable>)Parameters[VariableFrequenciesParameterName]; }
    }

    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    public ILookupParameter<DoubleValue> BestValidationQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[BestValidationQualityParameterName]; }
    }
    public ILookupParameter<SymbolicClassificationSolution> BestValidationSolutionParameter {
      get { return (ILookupParameter<SymbolicClassificationSolution>)Parameters[BestValidationSolutionParameterName]; }
    }
    public ILookupParameter<DoubleValue> BestSolutionAccuracyTrainingParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionAccuracyTrainingParameterName]; }
    }
    public ILookupParameter<DoubleValue> BestSolutionAccuracyTestParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionAccuracyTestParameterName]; }
    }
    #endregion
    #region properties
    public BoolValue Maximization {
      get { return MaximizationParameter.ActualValue; }
    }
    public IntValue Generations {
      get { return GenerationsParameter.ActualValue; }
    }
    public IRandom Random {
      get { return RandomParameter.ActualValue; }
    }
    public ItemArray<SymbolicExpressionTree> SymbolicExpressionTree {
      get { return SymbolicExpressionTreeParameter.ActualValue; }
    }
    public ISymbolicExpressionTreeInterpreter SymbolicExpressionTreeInterpreter {
      get { return SymbolicExpressionTreeInterpreterParameter.ActualValue; }
    }

    public ClassificationProblemData ClassificationProblemData {
      get { return ClassificationProblemDataParameter.ActualValue; }
    }
    public ISymbolicClassificationEvaluator Evaluator {
      get { return EvaluatorParameter.ActualValue; }
    }
    // NOTE: the member name is misspelled ("Validiation"); it is kept as-is because it is public API.
    public IntValue ValidiationSamplesStart {
      get { return ValidationSamplesStartParameter.ActualValue; }
    }
    public IntValue ValidationSamplesEnd {
      get { return ValidationSamplesEndParameter.ActualValue; }
    }
    public PercentValue RelativeNumberOfEvaluatedSamples {
      get { return RelativeNumberOfEvaluatedSamplesParameter.Value; }
    }
    public DoubleValue UpperEstimationLimit {
      get { return UpperEstimationLimitParameter.ActualValue; }
    }
    public DoubleValue LowerEstimationLimit {
      get { return LowerEstimationLimitParameter.ActualValue; }
    }
    public DataTable VariableFrequencies {
      get { return VariableFrequenciesParameter.ActualValue; }
    }

    public ResultCollection Results {
      get { return ResultsParameter.ActualValue; }
    }
    public DoubleValue BestValidationQuality {
      get { return BestValidationQualityParameter.ActualValue; }
      protected set { BestValidationQualityParameter.ActualValue = value; }
    }
    public SymbolicClassificationSolution BestValidationSolution {
      get { return BestValidationSolutionParameter.ActualValue; }
      protected set { BestValidationSolutionParameter.ActualValue = value; }
    }
    public DoubleValue BestSolutionAccuracyTraining {
      get { return BestSolutionAccuracyTrainingParameter.ActualValue; }
      protected set { BestSolutionAccuracyTrainingParameter.ActualValue = value; }
    }
    public DoubleValue BestSolutionAccuracyTest {
      get { return BestSolutionAccuracyTestParameter.ActualValue; }
      protected set { BestSolutionAccuracyTestParameter.ActualValue = value; }
    }
    #endregion

    [StorableConstructor]
    protected ValidationBestSymbolicClassificationSolutionAnalyzer(bool deserializing) : base(deserializing) { }
    protected ValidationBestSymbolicClassificationSolutionAnalyzer(ValidationBestSymbolicClassificationSolutionAnalyzer original, Cloner cloner)
      : base(original, cloner) {
    }
    public ValidationBestSymbolicClassificationSolutionAnalyzer()
      : base() {
      Parameters.Add(new LookupParameter<BoolValue>(MaximizationParameterName, "The direction of optimization."));
      Parameters.Add(new LookupParameter<IntValue>(GenerationsParameterName, "The number of generations calculated so far."));
      Parameters.Add(new LookupParameter<IRandom>(RandomParameterName, "The random generator to use."));
      Parameters.Add(new ScopeTreeLookupParameter<SymbolicExpressionTree>(SymbolicExpressionTreeParameterName, "The symbolic expression trees to analyze."));
      Parameters.Add(new ValueLookupParameter<ISymbolicExpressionTreeInterpreter>(SymbolicExpressionTreeInterpreterParameterName, "The interpreter that should be used for the analysis of symbolic expression trees."));

      Parameters.Add(new LookupParameter<ClassificationProblemData>(ClassificationProblemDataParameterName, "The problem data for which the symbolic expression tree is a solution."));
      Parameters.Add(new LookupParameter<ISymbolicClassificationEvaluator>(EvaluatorParameterName, "The evaluator which should be used to evaluate the solution on the validation set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesStartParameterName, "The first index of the validation partition of the data set."));
      Parameters.Add(new ValueLookupParameter<IntValue>(ValidationSamplesEndParameterName, "The last index of the validation partition of the data set."));
      Parameters.Add(new ValueParameter<PercentValue>(RelativeNumberOfEvaluatedSamplesParameterName, "The relative number of samples of the dataset partition, which should be randomly chosen for evaluation between the start and end index.", new PercentValue(1)));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(UpperEstimationLimitParameterName, "The upper estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new ValueLookupParameter<DoubleValue>(LowerEstimationLimitParameterName, "The lower estimation limit that was set for the evaluation of the symbolic expression trees."));
      Parameters.Add(new LookupParameter<DataTable>(VariableFrequenciesParameterName, "The variable frequencies table to use for the calculation of variable impacts"));

      Parameters.Add(new ValueLookupParameter<ResultCollection>(ResultsParameterName, "The results collection where the analysis values should be stored."));
      Parameters.Add(new LookupParameter<DoubleValue>(BestValidationQualityParameterName, "The validation quality of the best solution in the current run."));
      Parameters.Add(new LookupParameter<SymbolicClassificationSolution>(BestValidationSolutionParameterName, "The best solution on the validation data found in the current run."));
      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionAccuracyTrainingParameterName, "The training accuracy of the best solution."));
      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionAccuracyTestParameterName, "The test accuracy of the best solution."));
    }

    public override IDeepCloneable Clone(Cloner cloner) {
      return new ValidationBestSymbolicClassificationSolutionAnalyzer(this, cloner);
    }

    /// <summary>
    /// Evaluates all trees on one shared random sample of validation rows. If the best of them beats the
    /// best validation quality stored so far (or none is stored yet), it becomes the new best validation
    /// solution and the result entries are refreshed.
    /// </summary>
    public override IOperation Apply() {
      ClassificationProblemData problemData = ClassificationProblemData;
      string targetVariable = problemData.TargetVariable.Value;
      bool maximization = Maximization.Value;

      // The limit parameters are optional: a missing limit means "unbounded".
      double upperEstimationLimit = UpperEstimationLimit != null ? UpperEstimationLimit.Value : double.PositiveInfinity;
      double lowerEstimationLimit = LowerEstimationLimit != null ? LowerEstimationLimit.Value : double.NegativeInfinity;

      List<int> rows = SampleValidationRows(problemData);

      double bestQuality;
      SymbolicExpressionTree bestTree = FindBestTree(SymbolicExpressionTree, problemData, targetVariable, rows,
        lowerEstimationLimit, upperEstimationLimit, maximization, out bestQuality);

      // Empty scope tree, or no quality beat the initial infinite bound (e.g. all NaN): nothing to update.
      if (bestTree == null) return base.Apply();

      // if the best validation tree is better than the current best solution => update
      bool newBest =
        BestValidationQuality == null ||
        (maximization && bestQuality > BestValidationQuality.Value) ||
        (!maximization && bestQuality < BestValidationQuality.Value);
      if (newBest) {
        UpdateBestValidationSolution(bestTree, bestQuality, problemData, targetVariable, lowerEstimationLimit, upperEstimationLimit);
        UpdateBestSolutionResults();
      }
      return base.Apply();
    }

    /// <summary>
    /// Draws a random subset of row indices between ValidationSamplesStart and ValidationSamplesEnd
    /// (at least one row is requested, the fraction is RelativeNumberOfEvaluatedSamples), and drops rows
    /// that fall into the test partition.
    /// </summary>
    /// <returns>
    /// A materialized list, so that every tree is evaluated on exactly the same rows and the sampling
    /// query is not re-run per tree.
    /// </returns>
    private List<int> SampleValidationRows(ClassificationProblemData problemData) {
      int validationStart = ValidiationSamplesStart.Value;
      int validationEnd = ValidationSamplesEnd.Value;
      int testStart = problemData.TestSamplesStart.Value;
      int testEnd = problemData.TestSamplesEnd.Value;
      int seed = Random.Next();
      int count = (int)((validationEnd - validationStart) * RelativeNumberOfEvaluatedSamples.Value);
      if (count == 0) count = 1;
      return RandomEnumerable.SampleRandomNumbers(seed, validationStart, validationEnd, count)
        .Where(row => row < testStart || testEnd <= row)
        .ToList();
    }

    /// <summary>
    /// Evaluates each tree on <paramref name="rows"/> with the configured evaluator and returns the best one
    /// for the given optimization direction. On ties, the first tree wins.
    /// </summary>
    /// <param name="bestQuality">Receives the quality of the returned tree (or the initial infinite bound if none).</param>
    /// <returns>The best tree, or null if no tree's quality compared better than the initial infinite bound.</returns>
    private SymbolicExpressionTree FindBestTree(ItemArray<SymbolicExpressionTree> trees, ClassificationProblemData problemData,
      string targetVariable, IEnumerable<int> rows, double lowerEstimationLimit, double upperEstimationLimit,
      bool maximization, out double bestQuality) {
      ISymbolicClassificationEvaluator evaluator = Evaluator;
      ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;

      bestQuality = maximization ? double.NegativeInfinity : double.PositiveInfinity;
      SymbolicExpressionTree bestTree = null;
      foreach (var tree in trees) {
        double quality = evaluator.Evaluate(interpreter, tree,
          lowerEstimationLimit, upperEstimationLimit, problemData.Dataset,
          targetVariable, rows);

        if ((maximization && quality > bestQuality) ||
            (!maximization && quality < bestQuality)) {
          bestQuality = quality;
          bestTree = tree;
        }
      }
      return bestTree;
    }

    /// <summary>
    /// Linearly scales <paramref name="bestTree"/> using alpha/beta computed on the training partition, wraps it
    /// in a model with a cloned interpreter, and stores it together with <paramref name="bestQuality"/> as the
    /// new best validation solution.
    /// </summary>
    private void UpdateBestValidationSolution(SymbolicExpressionTree bestTree, double bestQuality, ClassificationProblemData problemData,
      string targetVariable, double lowerEstimationLimit, double upperEstimationLimit) {
      ISymbolicExpressionTreeInterpreter interpreter = SymbolicExpressionTreeInterpreter;

      double alpha, beta;
      SymbolicRegressionScaledMeanSquaredErrorEvaluator.Calculate(interpreter, bestTree,
        lowerEstimationLimit, upperEstimationLimit,
        problemData.Dataset, targetVariable,
        problemData.TrainingIndizes, out beta, out alpha);

      // scale tree for solution
      var scaledTree = SymbolicRegressionSolutionLinearScaler.Scale(bestTree, alpha, beta);
      var model = new SymbolicRegressionModel((ISymbolicExpressionTreeInterpreter)interpreter.Clone(),
        scaledTree);

      if (BestValidationSolution == null) {
        // Use the resolved limits: the limit parameters are optional and may be null.
        BestValidationSolution = new SymbolicClassificationSolution(problemData, model, lowerEstimationLimit, upperEstimationLimit);
        BestValidationSolution.Name = BestValidationSolutionParameterName;
        BestValidationSolution.Description = "Best solution on validation partition found over the whole run.";
      } else {
        BestValidationSolution.Model = model;
      }

      // The stored quality must follow every improvement. Otherwise later generations would only be
      // compared against the first stored quality and could replace a better solution with a worse one.
      if (BestValidationQuality == null) {
        BestValidationQuality = new DoubleValue(bestQuality);
      } else {
        BestValidationQuality.Value = bestQuality;
      }
    }

    /// <summary>
    /// Delegates the common result updates to
    /// <see cref="BestSymbolicRegressionSolutionAnalyzer.UpdateBestSolutionResults"/>, then computes the training
    /// and test accuracy of the best validation solution's estimated class values and stores them in the results.
    /// </summary>
    private void UpdateBestSolutionResults() {
      ClassificationProblemData problemData = ClassificationProblemData;
      BestSymbolicRegressionSolutionAnalyzer.UpdateBestSolutionResults(BestValidationSolution, problemData, Results, Generations, VariableFrequencies);

      string targetVariable = problemData.TargetVariable.Value;
      double trainingAccuracy = CalculateAccuracy(
        problemData.Dataset.GetEnumeratedVariableValues(targetVariable, problemData.TrainingIndizes),
        BestValidationSolution.EstimatedTrainingClassValues);
      double testAccuracy = CalculateAccuracy(
        problemData.Dataset.GetEnumeratedVariableValues(targetVariable, problemData.TestIndizes),
        BestValidationSolution.EstimatedTestClassValues);

      if (!Results.ContainsKey(BestSolutionAccuracyTrainingParameterName)) {
        BestSolutionAccuracyTraining = new DoubleValue(trainingAccuracy);
        BestSolutionAccuracyTest = new DoubleValue(testAccuracy);

        Results.Add(new Result(BestSolutionAccuracyTrainingParameterName, BestSolutionAccuracyTraining));
        Results.Add(new Result(BestSolutionAccuracyTestParameterName, BestSolutionAccuracyTest));
      } else {
        BestSolutionAccuracyTraining.Value = trainingAccuracy;
        BestSolutionAccuracyTest.Value = testAccuracy;
      }
    }

    /// <summary>
    /// Feeds position-wise pairs of original and estimated values into an <see cref="OnlineAccuracyEvaluator"/>.
    /// Pairing stops at the end of the shorter sequence.
    /// </summary>
    private static double CalculateAccuracy(IEnumerable<double> originalValues, IEnumerable<double> estimatedValues) {
      var accuracyEvaluator = new OnlineAccuracyEvaluator();
      using (IEnumerator<double> originalEnumerator = originalValues.GetEnumerator())
      using (IEnumerator<double> estimatedEnumerator = estimatedValues.GetEnumerator()) {
        // Non-short-circuit '&' so both enumerators are always advanced in lock step.
        while (originalEnumerator.MoveNext() & estimatedEnumerator.MoveNext()) {
          accuracyEvaluator.Add(originalEnumerator.Current, estimatedEnumerator.Current);
        }
      }
      return accuracyEvaluator.Accuracy;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.