Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.LinearRegression/3.2/LinearRegression.cs @ 2954

Last change on this file since 2954 was 2569, checked in by gkronber, 15 years ago

Added calculation of tree complexity to LR models. #821

File size: 13.3 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion

using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Text;
using System.Xml;
using HeuristicLab.Core;
using HeuristicLab.Data;
using HeuristicLab.DataAnalysis;
using HeuristicLab.GP;
using HeuristicLab.GP.Interfaces;
using HeuristicLab.GP.StructureIdentification;
using HeuristicLab.Modeling;
using HeuristicLab.Operators;
using HeuristicLab.Random;

namespace HeuristicLab.LinearRegression {
  /// <summary>
  /// Linear regression algorithm. The constructor builds an operator graph (see
  /// <see cref="CreateAlgorithm"/>) and installs it as the initial operator of a sequential engine.
  /// Problem settings (dataset, target variable, sample ranges) are read from and written to the
  /// variables of the operator returned by <see cref="ProblemInjector"/>.
  /// </summary>
  public class LinearRegression : ItemBase, IEditable, IAlgorithm {

    // Position of the problem-injector operator among the sub-operators of the main sequence;
    // CreateAlgorithm() adds the global injector (0), the random injector (1) and then this one (2).
    private const int ProblemInjectorIndex = 2;
    // Variable name under which the regression model is looked up by the analysis operators
    // and by CreateLRModel().
    private const string ModelVariableName = "LinearRegressionModel";
    // Actual name of the tree evaluator used by the model-analysis operators.
    private const string AnalysisTreeEvaluatorName = "ModelAnalysisTreeEvaluator";

    /// <summary>Name of the algorithm; also used as the name of the created operators.</summary>
    public virtual string Name { get { return "LinearRegression"; } }

    /// <summary>Description of the algorithm (placeholder text).</summary>
    public virtual string Description { get { return "TODO"; } }

    private IEngine engine;

    /// <summary>The engine that holds and executes the algorithm's operator graph.</summary>
    public virtual IEngine Engine {
      get { return engine; }
    }

    /// <summary>
    /// The dataset stored in the "Dataset" variable of the problem injector.
    /// Setting replaces the value of that variable.
    /// </summary>
    public virtual Dataset Dataset {
      get { return ProblemInjector.GetVariableValue<Dataset>("Dataset", null, false); }
      set { ProblemInjector.GetVariable("Dataset").Value = value; }
    }

    /// <summary>
    /// The target variable name stored in the "TargetVariable" variable of the problem injector.
    /// Setting updates the existing <c>StringData</c> item in place.
    /// </summary>
    public virtual string TargetVariable {
      get { return ProblemInjector.GetVariableValue<StringData>("TargetVariable", null, false).Data; }
      set { ProblemInjector.GetVariableValue<StringData>("TargetVariable", null, false).Data = value; }
    }

    /// <summary>
    /// The operator at index 0 below the initial operator of the problem-injector
    /// <c>CombinedOperator</c>. Setting removes that sub-operator and inserts the new one
    /// at the same position.
    /// </summary>
    public virtual IOperator ProblemInjector {
      get { return GetProblemInjectorParent().SubOperators[0]; }
      set {
        IOperator parent = GetProblemInjectorParent();
        parent.RemoveSubOperator(0);
        parent.AddSubOperator(value, 0);
      }
    }

    /// <summary>
    /// The names stored in the "AllowedFeatures" list of the problem injector.
    /// The getter returns a lazily evaluated view over that list.
    /// The setter APPENDS the given names to the existing list (it does not clear it first).
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    public IEnumerable<string> AllowedVariables {
      get {
        ItemList<StringData> allowedVariables = ProblemInjector.GetVariableValue<ItemList<StringData>>("AllowedFeatures", null, false);
        return allowedVariables.Select(x => x.Data);
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        ItemList<StringData> allowedVariables = ProblemInjector.GetVariableValue<ItemList<StringData>>("AllowedFeatures", null, false);
        // Materialize the new entries before touching the list: 'value' may be a deferred query
        // over this very list (e.g. the result of the getter), and appending while such a query
        // is still being enumerated would modify the collection during enumeration.
        List<StringData> additions = value.Select(x => new StringData(x)).ToList();
        foreach (StringData item in additions) allowedVariables.Add(item);
      }
    }

    /// <summary>Value of the "TrainingSamplesStart" variable of the problem injector.</summary>
    public int TrainingSamplesStart {
      get { return GetProblemInt("TrainingSamplesStart"); }
      set { SetProblemInt("TrainingSamplesStart", value); }
    }

    /// <summary>Value of the "TrainingSamplesEnd" variable of the problem injector.</summary>
    public int TrainingSamplesEnd {
      get { return GetProblemInt("TrainingSamplesEnd"); }
      set { SetProblemInt("TrainingSamplesEnd", value); }
    }

    /// <summary>Value of the "ValidationSamplesStart" variable of the problem injector.</summary>
    public int ValidationSamplesStart {
      get { return GetProblemInt("ValidationSamplesStart"); }
      set { SetProblemInt("ValidationSamplesStart", value); }
    }

    /// <summary>Value of the "ValidationSamplesEnd" variable of the problem injector.</summary>
    public int ValidationSamplesEnd {
      get { return GetProblemInt("ValidationSamplesEnd"); }
      set { SetProblemInt("ValidationSamplesEnd", value); }
    }

    /// <summary>Value of the "TestSamplesStart" variable of the problem injector.</summary>
    public int TestSamplesStart {
      get { return GetProblemInt("TestSamplesStart"); }
      set { SetProblemInt("TestSamplesStart", value); }
    }

    /// <summary>Value of the "TestSamplesEnd" variable of the problem injector.</summary>
    public int TestSamplesEnd {
      get { return GetProblemInt("TestSamplesEnd"); }
      set { SetProblemInt("TestSamplesEnd", value); }
    }

    /// <summary>
    /// The analyzer model built from the engine's global scope.
    /// </summary>
    /// <exception cref="InvalidOperationException">The engine has not terminated.</exception>
    public virtual IAnalyzerModel Model {
      get {
        if (!engine.Terminated) throw new InvalidOperationException("The algorithm is still running. Wait until the algorithm is terminated to retrieve the result.");
        IScope bestModelScope = engine.GlobalScope;
        return CreateLRModel(bestModelScope);
      }
    }

    /// <summary>
    /// Creates a sequential engine and installs the operator returned by
    /// <see cref="CreateAlgorithm"/> as its initial operator.
    /// </summary>
    public LinearRegression() {
      engine = new SequentialEngine.SequentialEngine();
      // NOTE(review): CreateAlgorithm() is virtual and runs before any subclass constructor body.
      CombinedOperator algo = CreateAlgorithm();
      engine.OperatorGraph.AddOperator(algo);
      engine.OperatorGraph.InitialOperator = algo;
    }

    /// <summary>
    /// Builds the algorithm as a <c>CombinedOperator</c> wrapping a sequence of: global injector,
    /// random injector, problem injector, tree-evaluator injector, linear regression operator
    /// and post-processing operator (in that order).
    /// </summary>
    /// <returns>The combined operator containing the sequence.</returns>
    protected virtual CombinedOperator CreateAlgorithm() {
      CombinedOperator algo = new CombinedOperator();
      SequentialProcessor seq = new SequentialProcessor();
      algo.Name = Name;
      seq.Name = Name;

      IOperator globalInjector = CreateGlobalInjector();

      HL3TreeEvaluatorInjector treeEvaluatorInjector = new HL3TreeEvaluatorInjector();

      LinearRegressionOperator lrOperator = new LinearRegressionOperator();
      lrOperator.GetVariableInfo("SamplesStart").ActualName = "ActualTrainingSamplesStart";
      lrOperator.GetVariableInfo("SamplesEnd").ActualName = "ActualTrainingSamplesEnd";

      // The order matters: ProblemInjector and GetVariableInjector() address sub-operators by index.
      seq.AddSubOperator(globalInjector);
      seq.AddSubOperator(new RandomInjector());
      seq.AddSubOperator(CreateProblemInjector());
      seq.AddSubOperator(treeEvaluatorInjector);
      seq.AddSubOperator(lrOperator);
      seq.AddSubOperator(CreatePostProcessingOperator());

      algo.OperatorGraph.InitialOperator = seq;
      algo.OperatorGraph.AddOperator(seq);

      return algo;
    }

    /// <summary>Creates the problem-injector operator via <c>DefaultRegressionOperators</c>.</summary>
    protected virtual IOperator CreateProblemInjector() {
      return DefaultRegressionOperators.CreateProblemInjector();
    }

    /// <summary>
    /// Creates the injector for the variables "PunishmentFactor" (1000), "TotalEvaluatedNodes" (0)
    /// and "MaxNumberOfTrainingSamples" (4000).
    /// </summary>
    protected virtual VariableInjector CreateGlobalInjector() {
      VariableInjector injector = new VariableInjector();
      injector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000)));
      injector.AddVariable(new HeuristicLab.Core.Variable("TotalEvaluatedNodes", new DoubleData(0)));
      injector.AddVariable(new HeuristicLab.Core.Variable("MaxNumberOfTrainingSamples", new IntData(4000)));

      return injector;
    }

    /// <summary>
    /// Builds the "Model Analyzer" operator: a sequence of a tree-evaluator injector, simple
    /// evaluators for the training, validation and test ranges, the variable-impact operators
    /// and finally the operator returned by <see cref="CreateModelAnalyzerOperator"/>.
    /// </summary>
    /// <returns>The combined operator containing the analysis sequence.</returns>
    protected virtual IOperator CreatePostProcessingOperator() {
      CombinedOperator op = new CombinedOperator();
      op.Name = "Model Analyzer";

      SequentialProcessor seq = new SequentialProcessor();
      HL3TreeEvaluatorInjector evaluatorInjector = new HL3TreeEvaluatorInjector();
      evaluatorInjector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000.0)));
      evaluatorInjector.GetVariableInfo("TreeEvaluator").ActualName = AnalysisTreeEvaluatorName;

      #region simple evaluators
      seq.AddSubOperator(evaluatorInjector);
      seq.AddSubOperator(CreateSimpleEvaluator("Training"));
      seq.AddSubOperator(CreateSimpleEvaluator("Validation"));
      seq.AddSubOperator(CreateSimpleEvaluator("Test"));
      #endregion

      #region variable impacts
      // calculate and set variable impacts
      VariableNamesExtractor namesExtractor = new VariableNamesExtractor();
      namesExtractor.GetVariableInfo("VariableNames").ActualName = "InputVariableNames";
      namesExtractor.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;

      PredictorBuilder predictorBuilder = new PredictorBuilder();
      predictorBuilder.GetVariableInfo("TreeEvaluator").ActualName = AnalysisTreeEvaluatorName;
      predictorBuilder.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;

      seq.AddSubOperator(namesExtractor);
      seq.AddSubOperator(predictorBuilder);
      VariableQualityImpactCalculator qualityImpactCalculator = new VariableQualityImpactCalculator();
      qualityImpactCalculator.GetVariableInfo("SamplesStart").ActualName = "TrainingSamplesStart";
      qualityImpactCalculator.GetVariableInfo("SamplesEnd").ActualName = "TrainingSamplesEnd";

      seq.AddSubOperator(qualityImpactCalculator);
      #endregion

      seq.AddSubOperator(CreateModelAnalyzerOperator());

      op.OperatorGraph.AddOperator(seq);
      op.OperatorGraph.InitialOperator = seq;
      return op;
    }

    /// <summary>Creates the final analysis operator via <c>DefaultRegressionOperators</c>.</summary>
    protected virtual IOperator CreateModelAnalyzerOperator() {
      return DefaultRegressionOperators.CreatePostProcessingOperator();
    }

    /// <summary>
    /// Creates an analyzer model from the given scope: stores tree size, height, complexity and
    /// average node complexity as meta data, lets <see cref="CreateSpecificLRModel"/> add further
    /// results, and registers each variable quality impact together with its input variable.
    /// </summary>
    /// <param name="bestModelScope">Scope holding the "LinearRegressionModel" and variable-impact results.</param>
    /// <returns>The populated analyzer model.</returns>
    protected virtual IAnalyzerModel CreateLRModel(IScope bestModelScope) {
      var model = new AnalyzerModel();
      IGeneticProgrammingModel gpModel = bestModelScope.GetVariableValue<IGeneticProgrammingModel>(ModelVariableName, false);
      model.SetMetaData("TreeSize", gpModel.Size);
      model.SetMetaData("TreeHeight", gpModel.Height);
      double treeComplexity = TreeComplexityEvaluator.Calculate(gpModel.FunctionTree);
      model.SetMetaData("TreeComplexity", treeComplexity);
      model.SetMetaData("AverageNodeComplexity", treeComplexity / gpModel.Size);

      CreateSpecificLRModel(bestModelScope, model);
      #region variable impacts
      ItemList qualityImpacts = bestModelScope.GetVariableValue<ItemList>(ModelingResult.VariableQualityImpact.ToString(), false);
      // Each row is itself an ItemList: element 0 is the variable name, element 1 its impact.
      foreach (ItemList row in qualityImpacts) {
        string variableName = ((StringData)row[0]).Data;
        double impact = ((DoubleData)row[1]).Data;
        model.SetVariableResult(ModelingResult.VariableQualityImpact, variableName, impact);
        model.AddInputVariable(variableName);
      }
      #endregion
      return model;
    }

    /// <summary>
    /// Hook for adding algorithm-specific results to the model; the default delegates to
    /// <c>DefaultRegressionOperators.PopulateAnalyzerModel</c>.
    /// </summary>
    protected virtual void CreateSpecificLRModel(IScope bestModelScope, IAnalyzerModel model) {
      DefaultRegressionOperators.PopulateAnalyzerModel(bestModelScope, model);
    }

    /// <summary>
    /// Returns the initial operator inside the engine's initial <c>CombinedOperator</c>
    /// (the main sequence built by <see cref="CreateAlgorithm"/>).
    /// </summary>
    protected virtual IOperator GetMainOperator() {
      CombinedOperator lr = (CombinedOperator)Engine.OperatorGraph.InitialOperator;
      return lr.OperatorGraph.InitialOperator;
    }

    /// <summary>Returns the first sub-operator of the main operator (the global injector).</summary>
    protected virtual IOperator GetVariableInjector() {
      return GetMainOperator().SubOperators[0];
    }

    /// <summary>Returns the view of the underlying engine.</summary>
    public override IView CreateView() {
      return engine.CreateView();
    }

    #region IEditable Members

    /// <summary>Returns the editor of the underlying engine, which must be a sequential engine.</summary>
    public virtual IEditor CreateEditor() {
      return ((SequentialEngine.SequentialEngine)engine).CreateEditor();
    }

    #endregion

    #region persistence
    /// <summary>Clones this algorithm and replaces the clone's engine with a clone of this engine.</summary>
    public override object Clone(IDictionary<Guid, object> clonedObjects) {
      LinearRegression clone = (LinearRegression)base.Clone(clonedObjects);
      clone.engine = (IEngine)Auxiliary.Clone(Engine, clonedObjects);
      return clone;
    }

    /// <summary>Persists the base state and appends the engine as child node "Engine".</summary>
    public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
      XmlNode node = base.GetXmlNode(name, document, persistedObjects);
      node.AppendChild(PersistenceManager.Persist("Engine", engine, document, persistedObjects));
      return node;
    }

    /// <summary>Restores the base state and the engine from the child node "Engine".</summary>
    public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
      base.Populate(node, restoredObjects);
      engine = (IEngine)PersistenceManager.Restore(node.SelectSingleNode("Engine"), restoredObjects);
    }
    #endregion

    #region private helpers
    // Returns the operator whose sub-operator at index 0 is the actual problem injector.
    private IOperator GetProblemInjectorParent() {
      IOperator main = GetMainOperator();
      CombinedOperator probInjector = (CombinedOperator)main.SubOperators[ProblemInjectorIndex];
      return probInjector.OperatorGraph.InitialOperator;
    }

    // Reads the IntData variable with the given name from the problem injector.
    private int GetProblemInt(string variableName) {
      return ProblemInjector.GetVariableValue<IntData>(variableName, null, false).Data;
    }

    // Writes the IntData variable with the given name of the problem injector in place.
    private void SetProblemInt(string variableName, int value) {
      ProblemInjector.GetVariableValue<IntData>(variableName, null, false).Data = value;
    }

    // Creates a SimpleEvaluator named "<partition>Evaluator" that evaluates the regression model on
    // the "<partition>SamplesStart".."<partition>SamplesEnd" range and writes to "<partition>Values".
    private static SimpleEvaluator CreateSimpleEvaluator(string partition) {
      SimpleEvaluator evaluator = new SimpleEvaluator();
      evaluator.Name = partition + "Evaluator";
      evaluator.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;
      evaluator.GetVariableInfo("SamplesStart").ActualName = partition + "SamplesStart";
      evaluator.GetVariableInfo("SamplesEnd").ActualName = partition + "SamplesEnd";
      evaluator.GetVariableInfo("Values").ActualName = partition + "Values";
      evaluator.GetVariableInfo("TreeEvaluator").ActualName = AnalysisTreeEvaluatorName;
      return evaluator;
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.