Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.Problems.DataAnalysis/3.3/SupportVectorMachine/ParameterAdjustmentProblem/SupportVectorMachineParameterAdjustmentBestSolutionAnalyzer.cs @ 3884

Last change on this file since 3884 was 3884, checked in by gkronber, 14 years ago

Worked on support vector regression operators and views. #1009

File size: 10.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2010 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Linq;
24using HeuristicLab.Common;
25using HeuristicLab.Core;
26using HeuristicLab.Data;
27using HeuristicLab.Encodings.RealVectorEncoding;
28using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
29using HeuristicLab.Problems.DataAnalysis.SupportVectorMachine;
30using HeuristicLab.Problems.DataAnalysis;
31using HeuristicLab.Problems.DataAnalysis.Evaluators;
32using HeuristicLab.Parameters;
33using HeuristicLab.Optimization;
34using HeuristicLab.Operators;
35
namespace HeuristicLab.Problems.DataAnalysis.SupportVectorMachine.ParameterAdjustmentProblem {
  /// <summary>
  /// Analyzer for the SVM parameter adjustment problem.
  /// On every application it selects, from the current scope tree, the parameter vector with the lowest
  /// cross-validation quality (lower is better). If that quality beats the best quality stored so far, or
  /// no best solution exists yet, it retrains an SVM model on the training partition with these parameters.
  /// It then publishes the model, its training/test mean squared error and the decoded parameters
  /// (nu, cost, gamma) in the result collection.
  /// </summary>
  [Item("SupportVectorMachineParameterAdjustmentBestSolutionAnalyzer", "Collects the parameters and the quality on training and test of the best solution for the SVM parameter adjustment problem.")]
  [StorableClass]
  public class SupportVectorMachineParameterAdjustmentBestSolutionAnalyzer : SingleSuccessorOperator, IAnalyzer {
    private const string ParameterVectorParameterName = "ParameterVector";
    private const string DataAnalysisProblemDataParameterName = "DataAnalysisProblemData";
    private const string SvmTypeParameterName = "SvmType";
    private const string KernelTypeParameterName = "KernelType";
    private const string QualityParameterName = "Quality";
    private const string BestSolutionParameterName = "BestSolution";
    private const string BestSolutionQualityParameterName = "BestSolutionQuality";
    private const string ResultsParameterName = "Results";
    private const string BestSolutionResultName = "Best solution (cross-validation)";
    private const string BestSolutionTrainingMse = "Best solution mean squared error (training)";
    private const string BestSolutionTestMse = "Best solution mean squared error (test)";
    private const string BestSolutionNu = "Best nu (cross-validation)";
    private const string BestSolutionCost = "Best cost (cross-validation)";
    private const string BestSolutionGamma = "Best gamma (cross-validation)";

    #region parameter properties
    /// <summary>The SVM parameter vectors of all solutions in the scope tree.</summary>
    public ILookupParameter<ItemArray<RealVector>> ParameterVectorParameter {
      get { return (ILookupParameter<ItemArray<RealVector>>)Parameters[ParameterVectorParameterName]; }
    }
    /// <summary>The problem data used to train and evaluate the best model.</summary>
    public IValueLookupParameter<DataAnalysisProblemData> DataAnalysisProblemDataParameter {
      get { return (IValueLookupParameter<DataAnalysisProblemData>)Parameters[DataAnalysisProblemDataParameterName]; }
    }
    /// <summary>The SVM type used when retraining the best model (default "NU_SVR").</summary>
    public IValueLookupParameter<StringValue> SvmTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[SvmTypeParameterName]; }
    }
    /// <summary>The kernel type used when retraining the best model (default "RBF").</summary>
    public IValueLookupParameter<StringValue> KernelTypeParameter {
      get { return (IValueLookupParameter<StringValue>)Parameters[KernelTypeParameterName]; }
    }
    /// <summary>The cross-validation qualities of all solutions in the scope tree (lower is better).</summary>
    public ILookupParameter<ItemArray<DoubleValue>> QualityParameter {
      get { return (ILookupParameter<ItemArray<DoubleValue>>)Parameters[QualityParameterName]; }
    }
    #endregion

    #region properties
    public DataAnalysisProblemData DataAnalysisProblemData {
      get { return DataAnalysisProblemDataParameter.ActualValue; }
    }
    public StringValue SvmType {
      get { return SvmTypeParameter.Value; }
    }
    public StringValue KernelType {
      get { return KernelTypeParameter.Value; }
    }
    /// <summary>The best model found so far; null until the first application.</summary>
    public ILookupParameter<SupportVectorMachineModel> BestSolutionParameter {
      get { return (ILookupParameter<SupportVectorMachineModel>)Parameters[BestSolutionParameterName]; }
    }
    /// <summary>The cross-validation quality of the best model found so far.</summary>
    public ILookupParameter<DoubleValue> BestSolutionQualityParameter {
      get { return (ILookupParameter<DoubleValue>)Parameters[BestSolutionQualityParameterName]; }
    }
    /// <summary>The result collection the best solution and its statistics are written to.</summary>
    public ILookupParameter<ResultCollection> ResultsParameter {
      get { return (ILookupParameter<ResultCollection>)Parameters[ResultsParameterName]; }
    }
    #endregion

    public SupportVectorMachineParameterAdjustmentBestSolutionAnalyzer()
      : base() {
      StringValue nuSvrType = new StringValue("NU_SVR").AsReadOnly();
      StringValue rbfKernelType = new StringValue("RBF").AsReadOnly();
      Parameters.Add(new ScopeTreeLookupParameter<RealVector>(ParameterVectorParameterName, "The parameters for the SVM encoded as a real vector."));
      Parameters.Add(new ValueLookupParameter<DataAnalysisProblemData>(DataAnalysisProblemDataParameterName, "The data analysis problem data to use for training."));
      Parameters.Add(new ValueLookupParameter<StringValue>(SvmTypeParameterName, "The type of SVM to use.", nuSvrType));
      Parameters.Add(new ValueLookupParameter<StringValue>(KernelTypeParameterName, "The kernel type to use for the SVM.", rbfKernelType));
      Parameters.Add(new ScopeTreeLookupParameter<DoubleValue>(QualityParameterName, "The cross validation quality reached with the given parameters."));
      Parameters.Add(new LookupParameter<SupportVectorMachineModel>(BestSolutionParameterName, "The best support vector solution."));
      Parameters.Add(new LookupParameter<DoubleValue>(BestSolutionQualityParameterName, "The quality of the best support vector model."));
      Parameters.Add(new LookupParameter<ResultCollection>(ResultsParameterName, "The result collection where the best support vector solution should be stored."));
    }

    /// <summary>
    /// Determines the best parameter vector of the current scope tree. If it improves on the stored best
    /// solution, or none exists yet, the method retrains the model and updates the stored best solution and the results.
    /// </summary>
    /// <exception cref="InvalidOperationException">
    /// Thrown when the number of parameter vectors and qualities differ.
    /// </exception>
    public override IOperation Apply() {
      ItemArray<RealVector> points = ParameterVectorParameter.ActualValue;
      ItemArray<DoubleValue> qualities = QualityParameter.ActualValue;

      // Nothing to analyze (e.g. an empty population): keep the current best solution untouched.
      if (points == null || qualities == null || points.Length == 0) {
        return base.Apply();
      }
      if (points.Length != qualities.Length) {
        throw new InvalidOperationException(string.Format(
          "The number of parameter vectors ({0}) does not match the number of qualities ({1}).",
          points.Length, qualities.Length));
      }

      // Minimization: keep the first occurrence of the lowest quality.
      int bestIndex = 0;
      for (int i = 1; i < points.Length; i++) {
        if (IsBetter(qualities[i].Value, qualities[bestIndex].Value)) {
          bestIndex = i;
        }
      }
      RealVector bestPoint = points[bestIndex];
      double bestQuality = qualities[bestIndex].Value;

      SupportVectorMachineModel currentBestModel = BestSolutionParameter.ActualValue;
      DoubleValue currentBestQuality = BestSolutionQualityParameter.ActualValue;
      bool improved = currentBestModel == null
        || currentBestQuality == null
        || IsBetter(bestQuality, currentBestQuality.Value);
      if (improved) {
        UpdateBestSolution(bestPoint, bestQuality);
      }

      return base.Apply();
    }

    /// <summary>
    /// Retrains the SVM on the training partition with the parameters decoded from <paramref name="point"/>.
    /// It stores the model and <paramref name="quality"/> as the new best solution and publishes the model,
    /// its training/test MSE and the decoded parameters in the result collection.
    /// </summary>
    /// <param name="point">Parameter vector: [0] = nu, [1] = log2(cost), [2] = log2(gamma).</param>
    /// <param name="quality">The cross-validation quality of <paramref name="point"/>.</param>
    private void UpdateBestSolution(RealVector point, double quality) {
      if (point.Length < 3) {
        throw new InvalidOperationException(string.Format(
          "The SVM parameter vector must contain at least 3 elements (nu, log2(cost), log2(gamma)) but has {0}.",
          point.Length));
      }
      DataAnalysisProblemData problemData = DataAnalysisProblemData;
      ResultCollection results = ResultsParameter.ActualValue;

      // nu is encoded directly; cost and gamma are encoded on a log2 scale.
      double nu = point[0];
      double cost = Math.Pow(2, point[1]);
      double gamma = Math.Pow(2, point[2]);

      int trainingStart = problemData.TrainingSamplesStart.Value;
      int trainingEnd = problemData.TrainingSamplesEnd.Value;
      int testStart = problemData.TestSamplesStart.Value;
      int testEnd = problemData.TestSamplesEnd.Value;

      // ActualValue honors a value supplied via the scope when the parameter has no local value.
      // NOTE(review): the trailing 0.0 is passed through unchanged from the previous implementation;
      // presumably it is epsilon — confirm against SupportVectorMachineModelCreator.TrainModel.
      SupportVectorMachineModel model = SupportVectorMachineModelCreator.TrainModel(problemData,
        trainingStart, trainingEnd,
        SvmTypeParameter.ActualValue.Value, KernelTypeParameter.ActualValue.Value, cost, nu, gamma, 0.0);
      BestSolutionParameter.ActualValue = model;
      BestSolutionQualityParameter.ActualValue = new DoubleValue(quality);
      SetResult(results, new Result(BestSolutionResultName, model));

      #region calculate MSE on training and test partitions
      double trainingMse = CalculateMse(model, problemData, trainingStart, trainingEnd);
      double testMse = CalculateMse(model, problemData, testStart, testEnd);
      SetResult(results, new Result(BestSolutionTrainingMse, new DoubleValue(trainingMse)));
      SetResult(results, new Result(BestSolutionTestMse, new DoubleValue(testMse)));
      SetResult(results, new Result(BestSolutionNu, new DoubleValue(nu)));
      SetResult(results, new Result(BestSolutionCost, new DoubleValue(cost)));
      SetResult(results, new Result(BestSolutionGamma, new DoubleValue(gamma)));
      #endregion
    }

    /// <summary>
    /// Computes the mean squared error of <paramref name="model"/> on the target variable over the row range
    /// [<paramref name="start"/>, <paramref name="end"/>) as delimited by the problem data's sample bounds.
    /// </summary>
    private static double CalculateMse(SupportVectorMachineModel model, DataAnalysisProblemData problemData, int start, int end) {
      double[] originalValues = problemData.Dataset.GetVariableValues(problemData.TargetVariable.Value, start, end);
      double[] estimatedValues = model.GetEstimatedValues(problemData, start, end).ToArray();
      return SimpleMSEEvaluator.Calculate(originalValues, estimatedValues);
    }

    /// <summary>
    /// Adds <paramref name="result"/> to <paramref name="results"/>, or, if a result with the same name already
    /// exists, replaces that result's value. This way neither duplicate names nor missing entries cause a failure.
    /// </summary>
    private static void SetResult(ResultCollection results, Result result) {
      if (results.ContainsKey(result.Name)) {
        results[result.Name].Value = result.Value;
      } else {
        results.Add(result);
      }
    }

    /// <summary>
    /// Returns true if <paramref name="candidate"/> is strictly better (lower) than <paramref name="incumbent"/>.
    /// A NaN candidate is never better, and any non-NaN candidate beats a NaN incumbent.
    /// </summary>
    private static bool IsBetter(double candidate, double incumbent) {
      if (double.IsNaN(candidate)) {
        return false;
      }
      return double.IsNaN(incumbent) || candidate < incumbent;
    }
  }
}
Note: See TracBrowser for help on using the repository browser.