source: trunk/sources/HeuristicLab.SupportVectorMachines/3.2/SupportVectorRegression.cs @ 2051

Last changed in revision 2051 by gkronber, 15 years ago, with the commit message:

Fixed a bug in calculation of variable impacts in GP and SVM algorithms. #644

File size: 19.0 KB
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30using HeuristicLab.Data;
31using HeuristicLab.Operators;
32using HeuristicLab.GP.StructureIdentification;
33using HeuristicLab.Logging;
34using HeuristicLab.Operators.Programmable;
35using HeuristicLab.Modeling;
36using HeuristicLab.Random;
37using HeuristicLab.Selection;
38
39namespace HeuristicLab.SupportVectorMachines {
40  public class SupportVectorRegression : ItemBase, IEditable, IAlgorithm {
41
42    public string Name { get { return "SupportVectorRegression"; } }
43    public string Description { get { return "TODO"; } }
44
45    private SequentialEngine.SequentialEngine engine;
46    public IEngine Engine {
47      get { return engine; }
48    }
49
50    public Dataset Dataset {
51      get { return ProblemInjector.GetVariableValue<Dataset>("Dataset", null, false); }
52      set { ProblemInjector.GetVariable("Dataset").Value = value; }
53    }
54
55    public int TargetVariable {
56      get { return ProblemInjector.GetVariableValue<IntData>("TargetVariable", null, false).Data; }
57      set { ProblemInjector.GetVariableValue<IntData>("TargetVariable", null, false).Data = value; }
58    }
59
60    public IOperator ProblemInjector {
61      get {
62        IOperator main = GetMainOperator();
63        return main.SubOperators[1];
64      }
65      set {
66        IOperator main = GetMainOperator();
67        main.RemoveSubOperator(1);
68        main.AddSubOperator(value, 1);
69      }
70    }
71
72    public IModel Model {
73      get {
74        if (!engine.Terminated) throw new InvalidOperationException("The algorithm is still running. Wait until the algorithm is terminated to retrieve the result.");
75        IScope bestModelScope = engine.GlobalScope.SubScopes[0];
76        return CreateSVMModel(bestModelScope);
77      }
78    }
79
80    public DoubleArrayData NuList {
81      get { return GetVariableInjector().GetVariable("NuList").GetValue<DoubleArrayData>(); }
82      set { GetVariableInjector().GetVariable("NuList").Value = value; }
83    }
84
85    public int MaxNuIndex {
86      get { return GetVariableInjector().GetVariable("MaxNuIndex").GetValue<IntData>().Data; }
87      set { GetVariableInjector().GetVariable("MaxNuIndex").GetValue<IntData>().Data = value; }
88    }
89
90    public DoubleArrayData CostList {
91      get { return GetVariableInjector().GetVariable("CostList").GetValue<DoubleArrayData>(); }
92      set { GetVariableInjector().GetVariable("CostList").Value = value; }
93    }
94
95    public int MaxCostIndex {
96      get { return GetVariableInjector().GetVariable("MaxCostIndex").GetValue<IntData>().Data; }
97      set { GetVariableInjector().GetVariable("MaxCostIndex").GetValue<IntData>().Data = value; }
98    }
99
100    public SupportVectorRegression() {
101      engine = new SequentialEngine.SequentialEngine();
102      CombinedOperator algo = CreateAlgorithm();
103      engine.OperatorGraph.AddOperator(algo);
104      engine.OperatorGraph.InitialOperator = algo;
105      MaxCostIndex = CostList.Data.Length;
106      MaxNuIndex = NuList.Data.Length;
107    }
108
109    private CombinedOperator CreateAlgorithm() {
110      CombinedOperator algo = new CombinedOperator();
111      algo.Name = "SupportVectorRegression";
112      IOperator main = CreateMainLoop();
113      algo.OperatorGraph.AddOperator(main);
114      algo.OperatorGraph.InitialOperator = main;
115      return algo;
116    }
117
118    private IOperator CreateMainLoop() {
119      SequentialProcessor main = new SequentialProcessor();
120      main.AddSubOperator(CreateGlobalInjector());
121      main.AddSubOperator(new ProblemInjector());
122      main.AddSubOperator(new RandomInjector());
123
124      SubScopesCreater modelScopeCreator = new SubScopesCreater();
125      modelScopeCreator.GetVariableInfo("SubScopes").Local = true;
126      modelScopeCreator.AddVariable(new HeuristicLab.Core.Variable("SubScopes", new IntData(1)));
127      main.AddSubOperator(modelScopeCreator);
128
129      SequentialSubScopesProcessor seqSubScopesProc = new SequentialSubScopesProcessor();
130      IOperator modelProcessor = CreateModelProcessor();
131      seqSubScopesProc.AddSubOperator(modelProcessor);
132      main.AddSubOperator(seqSubScopesProc);
133
134      SequentialProcessor nuLoop = new SequentialProcessor();
135      nuLoop.Name = "NuLoop";
136
137      IOperator costCounter = CreateCounter("Cost");
138      IOperator costComparator = CreateComparator("Cost");
139      nuLoop.AddSubOperator(costCounter);
140      nuLoop.AddSubOperator(costComparator);
141      ConditionalBranch costBranch = new ConditionalBranch();
142      costBranch.Name = "IfValidCostIndex";
143      costBranch.GetVariableInfo("Condition").ActualName = "RepeatCostLoop";
144
145      // build cost loop
146      SequentialProcessor costLoop = new SequentialProcessor();
147      costLoop.Name = "CostLoop";
148      costLoop.AddSubOperator(modelScopeCreator);
149      SequentialSubScopesProcessor subScopesProcessor = new SequentialSubScopesProcessor();
150      costLoop.AddSubOperator(subScopesProcessor);
151      subScopesProcessor.AddSubOperator(new EmptyOperator());
152      subScopesProcessor.AddSubOperator(modelProcessor);
153
154      Sorter sorter = new Sorter();
155      sorter.GetVariableInfo("Value").ActualName = "ValidationQuality";
156      sorter.GetVariableInfo("Descending").Local = true;
157      sorter.AddVariable(new Variable("Descending", new BoolData(false)));
158      costLoop.AddSubOperator(sorter);
159
160      LeftSelector selector = new LeftSelector();
161      selector.GetVariableInfo("Selected").Local = true;
162      selector.AddVariable(new Variable("Selected", new IntData(1)));
163      costLoop.AddSubOperator(selector);
164
165      RightReducer reducer = new RightReducer();
166      costLoop.AddSubOperator(reducer);
167
168      costLoop.AddSubOperator(costCounter);
169      costLoop.AddSubOperator(costComparator);
170
171      costBranch.AddSubOperator(costLoop);
172      costLoop.AddSubOperator(costBranch);
173
174      nuLoop.AddSubOperator(costBranch);
175      nuLoop.AddSubOperator(CreateResetOperator("CostIndex"));
176
177      nuLoop.AddSubOperator(CreateCounter("Nu"));
178      nuLoop.AddSubOperator(CreateComparator("Nu"));
179
180      ConditionalBranch nuBranch = new ConditionalBranch();
181      nuBranch.Name = "NuLoop";
182      nuBranch.GetVariableInfo("Condition").ActualName = "RepeatNuLoop";
183      nuBranch.AddSubOperator(nuLoop);
184      nuLoop.AddSubOperator(nuBranch);
185
186      main.AddSubOperator(nuLoop);
187      main.AddSubOperator(CreateModelAnalyser());
188      return main;
189    }
190
191    private IOperator CreateModelProcessor() {
192      SequentialProcessor modelProcessor = new SequentialProcessor();
193      modelProcessor.AddSubOperator(CreateSetNextParameterValueOperator("Nu"));
194      modelProcessor.AddSubOperator(CreateSetNextParameterValueOperator("Cost"));
195
196      SupportVectorCreator modelCreator = new SupportVectorCreator();
197      modelCreator.GetVariableInfo("SamplesStart").ActualName = "TrainingSamplesStart";
198      modelCreator.GetVariableInfo("SamplesEnd").ActualName = "TrainingSamplesEnd";
199      modelCreator.GetVariableInfo("SVMCost").ActualName = "Cost";
200      modelCreator.GetVariableInfo("SVMGamma").ActualName = "Gamma";
201      modelCreator.GetVariableInfo("SVMKernelType").ActualName = "KernelType";
202      modelCreator.GetVariableInfo("SVMModel").ActualName = "Model";
203      modelCreator.GetVariableInfo("SVMNu").ActualName = "Nu";
204      modelCreator.GetVariableInfo("SVMType").ActualName = "Type";
205
206      modelProcessor.AddSubOperator(modelCreator);
207      CombinedOperator trainingEvaluator = (CombinedOperator)CreateEvaluator("Training");
208      trainingEvaluator.OperatorGraph.InitialOperator.SubOperators[1].GetVariableInfo("MSE").ActualName = "Quality";
209      modelProcessor.AddSubOperator(trainingEvaluator);
210      modelProcessor.AddSubOperator(CreateEvaluator("Validation"));
211      modelProcessor.AddSubOperator(CreateEvaluator("Test"));
212
213      DataCollector collector = new DataCollector();
214      collector.GetVariableInfo("Values").ActualName = "Log";
215      ((ItemList<StringData>)collector.GetVariable("VariableNames").Value).Add(new StringData("Nu"));
216      ((ItemList<StringData>)collector.GetVariable("VariableNames").Value).Add(new StringData("Cost"));
217      ((ItemList<StringData>)collector.GetVariable("VariableNames").Value).Add(new StringData("ValidationQuality"));
218      modelProcessor.AddSubOperator(collector);
219      return modelProcessor;
220    }
221
222    private IOperator CreateComparator(string p) {
223      LessThanComparator comparator = new LessThanComparator();
224      comparator.Name = p + "IndexComparator";
225      comparator.GetVariableInfo("LeftSide").ActualName = p + "Index";
226      comparator.GetVariableInfo("RightSide").ActualName = "Max" + p + "Index";
227      comparator.GetVariableInfo("Result").ActualName = "Repeat" + p + "Loop";
228      return comparator;
229    }
230
231    private IOperator CreateCounter(string p) {
232      Counter c = new Counter();
233      c.GetVariableInfo("Value").ActualName = p + "Index";
234      c.Name = p + "Counter";
235      return c;
236    }
237
238    private IOperator CreateEvaluator(string p) {
239      CombinedOperator op = new CombinedOperator();
240      op.Name = p + "Evaluator";
241      SequentialProcessor seqProc = new SequentialProcessor();
242
243      SupportVectorEvaluator evaluator = new SupportVectorEvaluator();
244      evaluator.Name = p + "SimpleEvaluator";
245      evaluator.GetVariableInfo("SVMModel").ActualName = "Model";
246      evaluator.GetVariableInfo("SamplesStart").ActualName = p + "SamplesStart";
247      evaluator.GetVariableInfo("SamplesEnd").ActualName = p + "SamplesEnd";
248      evaluator.GetVariableInfo("Values").ActualName = p + "Values";
249      SimpleMSEEvaluator mseEvaluator = new SimpleMSEEvaluator();
250      mseEvaluator.Name = p + "MseEvaluator";
251      mseEvaluator.GetVariableInfo("Values").ActualName = p + "Values";
252      mseEvaluator.GetVariableInfo("MSE").ActualName = p + "Quality";
253      SimpleR2Evaluator r2Evaluator = new SimpleR2Evaluator();
254      r2Evaluator.Name = p + "R2Evaluator";
255      r2Evaluator.GetVariableInfo("Values").ActualName = p + "Values";
256      r2Evaluator.GetVariableInfo("R2").ActualName = p + "R2";
257      SimpleMeanAbsolutePercentageErrorEvaluator mapeEvaluator = new SimpleMeanAbsolutePercentageErrorEvaluator();
258      mapeEvaluator.Name = p + "MAPEEvaluator";
259      mapeEvaluator.GetVariableInfo("Values").ActualName = p + "Values";
260      mapeEvaluator.GetVariableInfo("MAPE").ActualName = p + "MAPE";
261      SimpleMeanAbsolutePercentageOfRangeErrorEvaluator mapreEvaluator = new SimpleMeanAbsolutePercentageOfRangeErrorEvaluator();
262      mapreEvaluator.Name = p + "MAPREEvaluator";
263      mapreEvaluator.GetVariableInfo("Values").ActualName = p + "Values";
264      mapreEvaluator.GetVariableInfo("MAPRE").ActualName = p + "MAPRE";
265      SimpleVarianceAccountedForEvaluator vafEvaluator = new SimpleVarianceAccountedForEvaluator();
266      vafEvaluator.Name = p + "VAFEvaluator";
267      vafEvaluator.GetVariableInfo("Values").ActualName = p + "Values";
268      vafEvaluator.GetVariableInfo("VAF").ActualName = p + "VAF";
269
270      seqProc.AddSubOperator(evaluator);
271      seqProc.AddSubOperator(mseEvaluator);
272      seqProc.AddSubOperator(r2Evaluator);
273      seqProc.AddSubOperator(mapeEvaluator);
274      seqProc.AddSubOperator(mapreEvaluator);
275      seqProc.AddSubOperator(vafEvaluator);
276
277      op.OperatorGraph.AddOperator(seqProc);
278      op.OperatorGraph.InitialOperator = seqProc;
279      return op;
280    }
281
282    private IOperator CreateSetNextParameterValueOperator(string paramName) {
283      ProgrammableOperator progOp = new ProgrammableOperator();
284      progOp.Name = "SetNext" + paramName;
285      progOp.RemoveVariableInfo("Result");
286      progOp.AddVariableInfo(new VariableInfo("Value", "Value", typeof(DoubleData), VariableKind.New));
287      progOp.AddVariableInfo(new VariableInfo("ValueIndex", "ValueIndex", typeof(IntData), VariableKind.In));
288      progOp.AddVariableInfo(new VariableInfo("ValueList", "ValueList", typeof(DoubleArrayData), VariableKind.In));
289      progOp.Code =
290@"
291Value.Data = ValueList.Data[ValueIndex.Data];
292";
293
294      progOp.GetVariableInfo("Value").ActualName = paramName;
295      progOp.GetVariableInfo("ValueIndex").ActualName = paramName + "Index";
296      progOp.GetVariableInfo("ValueList").ActualName = paramName + "List";
297      return progOp;
298    }
299
300    private IOperator CreateResetOperator(string paramName) {
301      ProgrammableOperator progOp = new ProgrammableOperator();
302      progOp.Name = "Reset" + paramName;
303      progOp.RemoveVariableInfo("Result");
304      progOp.AddVariableInfo(new VariableInfo("Value", "Value", typeof(IntData), VariableKind.In | VariableKind.Out));
305      progOp.Code = "Value.Data = -1;";
306      progOp.GetVariableInfo("Value").ActualName = paramName;
307      return progOp;
308    }
309
310    private IOperator CreateGlobalInjector() {
311      VariableInjector injector = new VariableInjector();
312      injector.AddVariable(new HeuristicLab.Core.Variable("CostIndex", new IntData(0)));
313      injector.AddVariable(new HeuristicLab.Core.Variable("CostList", new DoubleArrayData(new double[] { 0.1, 0.25, 0.5, 1.0, 2.0, 4.0, 8.0 })));
314      injector.AddVariable(new HeuristicLab.Core.Variable("MaxCostIndex", new IntData()));
315      injector.AddVariable(new HeuristicLab.Core.Variable("NuIndex", new IntData(0)));
316      injector.AddVariable(new HeuristicLab.Core.Variable("NuList", new DoubleArrayData(new double[] { 0.01, 0.05, 0.1, 0.5 })));
317      injector.AddVariable(new HeuristicLab.Core.Variable("MaxNuIndex", new IntData()));
318      injector.AddVariable(new HeuristicLab.Core.Variable("Log", new ItemList()));
319      injector.AddVariable(new HeuristicLab.Core.Variable("Gamma", new DoubleData(1)));
320      injector.AddVariable(new HeuristicLab.Core.Variable("KernelType", new StringData("RBF")));
321      injector.AddVariable(new HeuristicLab.Core.Variable("Type", new StringData("NU_SVR")));
322
323      return injector;
324    }
325
326    private IOperator CreateModelAnalyser() {
327      CombinedOperator modelAnalyser = new CombinedOperator();
328      modelAnalyser.Name = "Model Analyzer";
329      SequentialSubScopesProcessor seqSubScopeProc = new SequentialSubScopesProcessor();
330      SequentialProcessor seqProc = new SequentialProcessor();
331      VariableEvaluationImpactCalculator evalImpactCalc = new VariableEvaluationImpactCalculator();
332      evalImpactCalc.GetVariableInfo("SVMModel").ActualName = "Model";
333      VariableQualityImpactCalculator qualImpactCalc = new VariableQualityImpactCalculator();
334      qualImpactCalc.GetVariableInfo("SVMModel").ActualName = "Model";
335
336      seqProc.AddSubOperator(evalImpactCalc);
337      seqProc.AddSubOperator(qualImpactCalc);
338      seqSubScopeProc.AddSubOperator(seqProc);
339      modelAnalyser.OperatorGraph.InitialOperator = seqSubScopeProc;
340      modelAnalyser.OperatorGraph.AddOperator(seqSubScopeProc);
341      return modelAnalyser;
342    }
343
344
345    protected internal virtual Model CreateSVMModel(IScope bestModelScope) {
346      Model model = new Model();
347      model.TrainingMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("Quality", false).Data;
348      model.ValidationMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("ValidationQuality", false).Data;
349      model.TestMeanSquaredError = bestModelScope.GetVariableValue<DoubleData>("TestQuality", false).Data;
350      model.TrainingCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("TrainingR2", false).Data;
351      model.ValidationCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("ValidationR2", false).Data;
352      model.TestCoefficientOfDetermination = bestModelScope.GetVariableValue<DoubleData>("TestR2", false).Data;
353      model.TrainingMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("TrainingMAPE", false).Data;
354      model.ValidationMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("ValidationMAPE", false).Data;
355      model.TestMeanAbsolutePercentageError = bestModelScope.GetVariableValue<DoubleData>("TestMAPE", false).Data;
356      model.TrainingMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("TrainingMAPRE", false).Data;
357      model.ValidationMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("ValidationMAPRE", false).Data;
358      model.TestMeanAbsolutePercentageOfRangeError = bestModelScope.GetVariableValue<DoubleData>("TestMAPRE", false).Data;
359      model.TrainingVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("TrainingVAF", false).Data;
360      model.ValidationVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("ValidationVAF", false).Data;
361      model.TestVarianceAccountedFor = bestModelScope.GetVariableValue<DoubleData>("TestVAF", false).Data;
362
363      model.Data = bestModelScope.GetVariableValue<SVMModel>("BestValidationModel", false);
364      HeuristicLab.DataAnalysis.Dataset ds = bestModelScope.GetVariableValue<Dataset>("Dataset", true);
365      model.Dataset = ds;
366      model.TargetVariable = ds.GetVariableName(bestModelScope.GetVariableValue<IntData>("TargetVariable", true).Data);
367
368      ItemList evaluationImpacts = bestModelScope.GetVariableValue<ItemList>("VariableEvaluationImpacts", false);
369      ItemList qualityImpacts = bestModelScope.GetVariableValue<ItemList>("VariableQualityImpacts", false);
370      foreach (ItemList row in evaluationImpacts) {
371        string variableName = ((StringData)row[0]).Data;
372        double impact = ((DoubleData)row[1]).Data;
373        model.SetVariableEvaluationImpact(variableName, impact);
374      }
375      foreach (ItemList row in qualityImpacts) {
376        string variableName = ((StringData)row[0]).Data;
377        double impact = ((DoubleData)row[1]).Data;
378        model.SetVariableQualityImpact(variableName, impact);
379      }
380
381      return model;
382    }
383
384    private IOperator GetVariableInjector() {
385      return GetMainOperator().SubOperators[0];
386    }
387
388    private IOperator GetMainOperator() {
389      CombinedOperator svm = (CombinedOperator)Engine.OperatorGraph.InitialOperator;
390      return svm.OperatorGraph.InitialOperator;
391    }
392
393    public override IView CreateView() {
394      return engine.CreateView();
395    }
396
397    #region IEditable Members
398
399    public IEditor CreateEditor() {
400      return engine.CreateEditor();
401    }
402
403    #endregion
404  }
405}
Note: See TracBrowser for help on using the repository browser.