Free cookie consent management tool by TermsFeed Policy Generator

source: branches/HeuristicLab.Modeling Database Backend/sources/HeuristicLab.LinearRegression/3.2/LinearRegression.cs @ 2201

Last change on this file since 2201 was 2201, checked in by gkronber, 15 years ago

Added statements to set the input variables of models in all regression engines. #712

File size: 17.1 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30using HeuristicLab.Data;
31using HeuristicLab.Operators;
32using HeuristicLab.GP.StructureIdentification;
33using HeuristicLab.Modeling;
34using HeuristicLab.GP;
35using HeuristicLab.Random;
36
namespace HeuristicLab.LinearRegression {
  /// <summary>
  /// Linear regression algorithm expressed as a HeuristicLab operator graph and executed by a
  /// sequential engine. The graph injects a random generator, the problem variables and some global
  /// variables, shuffles the training partition of the dataset, runs the LinearRegressionOperator on
  /// the actual training samples and finally evaluates the resulting model (stored in the scope
  /// variable "LinearRegressionModel") on the training, validation and test partitions.
  /// </summary>
  public class LinearRegression : ItemBase, IEditable, IAlgorithm {
    // Actual name of the scope variable that holds the model produced by the LinearRegressionOperator
    // and that every evaluator / impact calculator reads as its "FunctionTree".
    private const string ModelVariableName = "LinearRegressionModel";

    // Position of the problem injector among the sub-operators of the main sequential processor
    // (index 0 is the random injector, see CreateAlgorithm).
    private const int ProblemInjectorIndex = 1;

    public string Name { get { return "LinearRegression"; } }
    public string Description { get { return "TODO"; } }

    private readonly SequentialEngine.SequentialEngine engine;

    /// <summary>
    /// Gets the engine that executes the operator graph of this algorithm.
    /// </summary>
    public IEngine Engine {
      get { return engine; }
    }

    /// <summary>
    /// Gets or sets the value of the "Dataset" variable of the problem injector.
    /// </summary>
    public Dataset Dataset {
      get { return ProblemInjector.GetVariableValue<Dataset>("Dataset", null, false); }
      set { ProblemInjector.GetVariable("Dataset").Value = value; }
    }

    /// <summary>
    /// Gets or sets the index (into the dataset's variables) of the variable to be predicted,
    /// stored in the "TargetVariable" variable of the problem injector.
    /// </summary>
    public int TargetVariable {
      get { return ProblemInjector.GetVariableValue<IntData>("TargetVariable", null, false).Data; }
      set { ProblemInjector.GetVariableValue<IntData>("TargetVariable", null, false).Data = value; }
    }

    /// <summary>
    /// Gets or sets the operator that injects the problem variables (dataset, target variable,
    /// sample ranges, ...). Setting it replaces the current problem injector in place.
    /// </summary>
    public IOperator ProblemInjector {
      get {
        IOperator main = GetMainOperator();
        return main.SubOperators[ProblemInjectorIndex];
      }
      set {
        IOperator main = GetMainOperator();
        main.RemoveSubOperator(ProblemInjectorIndex);
        main.AddSubOperator(value, ProblemInjectorIndex);
      }
    }

    /// <summary>
    /// Gets the model built from the variables in the engine's global scope.
    /// </summary>
    /// <exception cref="InvalidOperationException">The engine has not terminated yet.</exception>
    public IModel Model {
      get {
        if (!engine.Terminated) throw new InvalidOperationException("The algorithm is still running. Wait until the algorithm is terminated to retrieve the result.");
        IScope bestModelScope = engine.GlobalScope;
        return CreateLRModel(bestModelScope);
      }
    }

    /// <summary>
    /// Creates a new sequential engine whose initial operator is the linear regression operator graph.
    /// </summary>
    public LinearRegression() {
      engine = new SequentialEngine.SequentialEngine();
      CombinedOperator algo = CreateAlgorithm();
      engine.OperatorGraph.AddOperator(algo);
      engine.OperatorGraph.InitialOperator = algo;
    }

    /// <summary>
    /// Builds the top-level combined operator. Its initial operator is a sequential processor whose
    /// sub-operators run in this order: random injector, problem injector, global injector, dataset
    /// shuffler, linear regression operator, model analyser.
    /// </summary>
    private CombinedOperator CreateAlgorithm() {
      CombinedOperator algo = new CombinedOperator();
      SequentialProcessor seq = new SequentialProcessor();
      algo.Name = "LinearRegression";
      seq.Name = "LinearRegression";

      var randomInjector = new RandomInjector();
      randomInjector.Name = "Random Injector";
      IOperator globalInjector = CreateGlobalInjector();
      ProblemInjector problemInjector = new ProblemInjector();
      // Keep MaxNumberOfTrainingSamples local to the problem injector with a default of 5000.
      problemInjector.GetVariableInfo("MaxNumberOfTrainingSamples").Local = true;
      problemInjector.AddVariable(new HeuristicLab.Core.Variable("MaxNumberOfTrainingSamples", new IntData(5000)));

      // Shuffle only the training partition of the dataset.
      IOperator shuffler = new DatasetShuffler();
      shuffler.GetVariableInfo("ShuffleStart").ActualName = "TrainingSamplesStart";
      shuffler.GetVariableInfo("ShuffleEnd").ActualName = "TrainingSamplesEnd";

      LinearRegressionOperator lrOperator = new LinearRegressionOperator();
      lrOperator.GetVariableInfo("SamplesStart").ActualName = "ActualTrainingSamplesStart";
      lrOperator.GetVariableInfo("SamplesEnd").ActualName = "ActualTrainingSamplesEnd";

      // The order of these sub-operators is significant: ProblemInjectorIndex refers to it.
      seq.AddSubOperator(randomInjector);
      seq.AddSubOperator(problemInjector);
      seq.AddSubOperator(globalInjector);
      seq.AddSubOperator(shuffler);
      seq.AddSubOperator(lrOperator);
      seq.AddSubOperator(CreateModelAnalyser());

      algo.OperatorGraph.InitialOperator = seq;
      algo.OperatorGraph.AddOperator(seq);

      return algo;
    }

    /// <summary>
    /// Creates the injector for the global variables used by the evaluators
    /// (punishment factor, evaluated-node counter, tree evaluator, estimated-target flag).
    /// </summary>
    private IOperator CreateGlobalInjector() {
      VariableInjector injector = new VariableInjector();
      injector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(10)));
      injector.AddVariable(new HeuristicLab.Core.Variable("TotalEvaluatedNodes", new DoubleData(0)));
      injector.AddVariable(new HeuristicLab.Core.Variable("TreeEvaluator", new HL2TreeEvaluator()));
      injector.AddVariable(new HeuristicLab.Core.Variable("UseEstimatedTargetValue", new BoolData(false)));

      return injector;
    }

    /// <summary>
    /// Creates the combined operator that evaluates the model with MSE, R2, MAPE, MAPRE and VAF on
    /// the training, validation and test partitions, and then computes the variable quality and
    /// evaluation impacts.
    /// </summary>
    private IOperator CreateModelAnalyser() {
      CombinedOperator modelAnalyser = new CombinedOperator();
      modelAnalyser.Name = "Model Analyzer";
      SequentialProcessor seqProc = new SequentialProcessor();

      // Sub-operators are executed in the order they are added: metric by metric, and for each
      // metric training, validation and test. The MSE result is stored as "<Partition>Quality".
      AddPartitionEvaluators<MeanSquaredErrorEvaluator>(seqProc, "Mse", "MSE", "Quality");
      AddPartitionEvaluators<CoefficientOfDeterminationEvaluator>(seqProc, "R2", "R2", "R2");
      AddPartitionEvaluators<MeanAbsolutePercentageErrorEvaluator>(seqProc, "Mape", "MAPE", "MAPE");
      AddPartitionEvaluators<MeanAbsolutePercentageOfRangeErrorEvaluator>(seqProc, "Mapre", "MAPRE", "MAPRE");
      AddPartitionEvaluators<VarianceAccountedForEvaluator>(seqProc, "Vaf", "VAF", "VAF");

      HeuristicLab.GP.StructureIdentification.VariableEvaluationImpactCalculator evalImpactCalc = new HeuristicLab.GP.StructureIdentification.VariableEvaluationImpactCalculator();
      evalImpactCalc.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;
      HeuristicLab.Modeling.VariableQualityImpactCalculator qualImpactCalc = new HeuristicLab.GP.StructureIdentification.VariableQualityImpactCalculator();
      qualImpactCalc.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;
      seqProc.AddSubOperator(qualImpactCalc);
      seqProc.AddSubOperator(evalImpactCalc);

      modelAnalyser.OperatorGraph.InitialOperator = seqProc;
      modelAnalyser.OperatorGraph.AddOperator(seqProc);
      return modelAnalyser;
    }

    /// <summary>
    /// Adds three evaluators of type <typeparamref name="TEvaluator"/> to <paramref name="target"/>,
    /// one each for the training, validation and test partitions (in that order).
    /// </summary>
    /// <param name="target">Operator that receives the evaluators as sub-operators.</param>
    /// <param name="nameInfix">Inserted into the evaluator name, e.g. "Mse" gives "TrainingMseEvaluator".</param>
    /// <param name="resultFormalName">Formal name of the evaluator's result variable, e.g. "MSE".</param>
    /// <param name="resultSuffix">Appended to the partition name to form the actual result variable name.</param>
    private static void AddPartitionEvaluators<TEvaluator>(IOperator target, string nameInfix, string resultFormalName, string resultSuffix)
      where TEvaluator : IOperator, new() {
      // Training metrics use the "Actual" training range, i.e. the range the LR operator was run on.
      AddEvaluator(target, new TEvaluator(), "Training", nameInfix, resultFormalName, resultSuffix, "ActualTrainingSamples");
      AddEvaluator(target, new TEvaluator(), "Validation", nameInfix, resultFormalName, resultSuffix, "ValidationSamples");
      AddEvaluator(target, new TEvaluator(), "Test", nameInfix, resultFormalName, resultSuffix, "TestSamples");
    }

    /// <summary>
    /// Names one evaluator, maps its model, result and sample-range variables, and appends it to
    /// <paramref name="target"/>.
    /// </summary>
    private static void AddEvaluator(IOperator target, IOperator evaluator, string partition, string nameInfix,
      string resultFormalName, string resultSuffix, string samplesPrefix) {
      evaluator.Name = partition + nameInfix + "Evaluator";
      evaluator.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;
      evaluator.GetVariableInfo(resultFormalName).ActualName = partition + resultSuffix;
      evaluator.GetVariableInfo("SamplesStart").ActualName = samplesPrefix + "Start";
      evaluator.GetVariableInfo("SamplesEnd").ActualName = samplesPrefix + "End";
      target.AddSubOperator(evaluator);
    }

    /// <summary>
    /// Creates a <see cref="Model"/> from the quality values, the model tree, the dataset, the target
    /// variable and the variable impacts stored in <paramref name="bestModelScope"/>.
    /// </summary>
    /// <param name="bestModelScope">Scope containing the results written by the model analyser.</param>
    /// <returns>The populated model.</returns>
    protected internal virtual Model CreateLRModel(IScope bestModelScope) {
      Model model = new Model();
      model.TrainingMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("TrainingQuality", false).Data;
      model.ValidationMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("ValidationQuality", false).Data;
      model.TestMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("TestQuality", false).Data;
      model.TrainingCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("TrainingR2", false).Data;
      model.ValidationCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("ValidationR2", false).Data;
      model.TestCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("TestR2", false).Data;
      model.TrainingMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("TrainingMAPE", false).Data;
      model.ValidationMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("ValidationMAPE", false).Data;
      model.TestMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("TestMAPE", false).Data;
      model.TrainingMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("TrainingMAPRE", false).Data;
      model.ValidationMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("ValidationMAPRE", false).Data;
      model.TestMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("TestMAPRE", false).Data;
      model.TrainingVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("TrainingVAF", false).Data;
      model.ValidationVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("ValidationVAF", false).Data;
      model.TestVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("TestVAF", false).Data;

      model.Data = bestModelScope.GetVariableValue<IFunctionTree>(ModelVariableName, false);
      HeuristicLab.DataAnalysis.Dataset ds = bestModelScope.GetVariableValue<Dataset>("Dataset", true);
      model.Dataset = ds;
      model.TargetVariable = ds.GetVariableName(bestModelScope.GetVariableValue<IntData>("TargetVariable", true).Data);

      ItemList evaluationImpacts = bestModelScope.GetVariableValue<ItemList>("VariableEvaluationImpacts", false);
      ItemList qualityImpacts = bestModelScope.GetVariableValue<ItemList>("VariableQualityImpacts", false);
      // A variable can appear in both impact lists; register it as an input variable only once
      // (in first-seen order) instead of once per list.
      HashSet<string> registeredInputVariables = new HashSet<string>();
      foreach (ItemList row in evaluationImpacts) {
        string variableName = ((StringData)row[0]).Data;
        double impact = ((DoubleData)row[1]).Data;
        model.SetVariableEvaluationImpact(variableName, impact);
        if (registeredInputVariables.Add(variableName)) model.AddInputVariables(variableName);
      }
      foreach (ItemList row in qualityImpacts) {
        string variableName = ((StringData)row[0]).Data;
        double impact = ((DoubleData)row[1]).Data;
        model.SetVariableQualityImpact(variableName, impact);
        if (registeredInputVariables.Add(variableName)) model.AddInputVariables(variableName);
      }

      return model;
    }

    /// <summary>
    /// Returns the main sequential processor, i.e. the initial operator of the combined operator
    /// that is the engine's initial operator.
    /// </summary>
    private IOperator GetMainOperator() {
      CombinedOperator lr = (CombinedOperator)Engine.OperatorGraph.InitialOperator;
      return lr.OperatorGraph.InitialOperator;
    }

    /// <summary>Delegates view creation to the engine.</summary>
    public override IView CreateView() {
      return engine.CreateView();
    }

    #region IEditable Members

    /// <summary>Delegates editor creation to the engine.</summary>
    public IEditor CreateEditor() {
      return engine.CreateEditor();
    }

    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.