Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.LinearRegression/3.2/LinearRegression.cs @ 2377

Last change on this file since 2377 was 2377, checked in by gkronber, 13 years ago
  • Implemented cloning and persistence in data-modeling algorithms.
  • Fixed bugs in CEDMA controller.

#754

File size: 11.8 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30using HeuristicLab.Data;
31using HeuristicLab.Operators;
32using HeuristicLab.GP.StructureIdentification;
33using HeuristicLab.Modeling;
34using HeuristicLab.GP;
35using HeuristicLab.Random;
36using HeuristicLab.GP.Interfaces;
37
namespace HeuristicLab.LinearRegression {
  /// <summary>
  /// Data-modeling algorithm that runs a <see cref="LinearRegressionOperator"/> on the training
  /// partition of a regression problem and then analyzes the resulting model on the training,
  /// validation and test partitions.
  /// </summary>
  /// <remarks>
  /// The algorithm is an operator graph executed by a <see cref="SequentialEngine.SequentialEngine"/>.
  /// The main <see cref="SequentialProcessor"/> built by <see cref="CreateAlgorithm"/> has these
  /// sub-operators, in order: global injector, random injector, problem injector, tree-evaluator
  /// injector, linear regression operator, post-processing ("Model Analyzer") operator.
  /// Several accessors below locate operators by these positions.
  /// </remarks>
  public class LinearRegression : ItemBase, IEditable, IAlgorithm {
    // Positions of sub-operators of the main SequentialProcessor; must match the order of
    // AddSubOperator calls in CreateAlgorithm().
    private const int GlobalInjectorIndex = 0;
    private const int ProblemInjectorIndex = 2;

    // Variable name under which the post-processing operators look up the fitted model tree.
    private const string ModelVariableName = "LinearRegressionModel";
    // Variable name of the tree evaluator injected for model analysis.
    private const string ModelAnalysisTreeEvaluatorName = "ModelAnalysisTreeEvaluator";

    /// <summary>Gets the display name of the algorithm.</summary>
    public virtual string Name { get { return "LinearRegression"; } }

    /// <summary>Gets the description of the algorithm (currently a placeholder text).</summary>
    public virtual string Description { get { return "TODO"; } }

    private IEngine engine;

    /// <summary>Gets the engine that executes the algorithm's operator graph.</summary>
    public virtual IEngine Engine {
      get { return engine; }
    }

    /// <summary>Gets or sets the value of the problem injector's "Dataset" variable.</summary>
    public virtual Dataset Dataset {
      get { return ProblemInjector.GetVariableValue<Dataset>("Dataset", null, false); }
      set { ProblemInjector.GetVariable("Dataset").Value = value; }
    }

    /// <summary>
    /// Gets or sets the problem injector's "TargetVariable" value
    /// (presumably a column index into <see cref="Dataset"/> — confirm).
    /// </summary>
    public virtual int TargetVariable {
      get { return GetProblemIntData("TargetVariable").Data; }
      set { GetProblemIntData("TargetVariable").Data = value; }
    }

    /// <summary>
    /// Gets or sets the operator that injects the problem variables (dataset, target variable,
    /// allowed features, sample ranges). It is the first sub-operator of the initial operator
    /// of the combined operator at <see cref="ProblemInjectorIndex"/> in the main operator.
    /// </summary>
    public virtual IOperator ProblemInjector {
      get {
        return GetProblemInjectorContainer().OperatorGraph.InitialOperator.SubOperators[0];
      }
      set {
        IOperator container = GetProblemInjectorContainer().OperatorGraph.InitialOperator;
        container.RemoveSubOperator(0);
        container.AddSubOperator(value, 0);
      }
    }

    /// <summary>
    /// Gets or sets the allowed input variables stored in the problem injector's
    /// "AllowedFeatures" list.
    /// </summary>
    /// <remarks>
    /// The getter returns a snapshot, not a live view of the list. The setter replaces the
    /// current contents of the list instead of appending to them.
    /// </remarks>
    /// <exception cref="ArgumentNullException"><paramref name="value"/> is null.</exception>
    public IEnumerable<int> AllowedVariables {
      get {
        ItemList<IntData> allowedVariables = GetAllowedFeatures();
        return allowedVariables.Select(x => x.Data).ToList();
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        // Materialize first: value may be a lazy view over the very list that is cleared below.
        List<int> newVariables = value.ToList();
        ItemList<IntData> allowedVariables = GetAllowedFeatures();
        allowedVariables.Clear();
        foreach (int x in newVariables) allowedVariables.Add(new IntData(x));
      }
    }

    /// <summary>Gets or sets the problem injector's "TrainingSamplesStart" value.</summary>
    public int TrainingSamplesStart {
      get { return GetProblemIntData("TrainingSamplesStart").Data; }
      set { GetProblemIntData("TrainingSamplesStart").Data = value; }
    }

    /// <summary>Gets or sets the problem injector's "TrainingSamplesEnd" value.</summary>
    public int TrainingSamplesEnd {
      get { return GetProblemIntData("TrainingSamplesEnd").Data; }
      set { GetProblemIntData("TrainingSamplesEnd").Data = value; }
    }

    /// <summary>Gets or sets the problem injector's "ValidationSamplesStart" value.</summary>
    public int ValidationSamplesStart {
      get { return GetProblemIntData("ValidationSamplesStart").Data; }
      set { GetProblemIntData("ValidationSamplesStart").Data = value; }
    }

    /// <summary>Gets or sets the problem injector's "ValidationSamplesEnd" value.</summary>
    public int ValidationSamplesEnd {
      get { return GetProblemIntData("ValidationSamplesEnd").Data; }
      set { GetProblemIntData("ValidationSamplesEnd").Data = value; }
    }

    /// <summary>Gets or sets the problem injector's "TestSamplesStart" value.</summary>
    public int TestSamplesStart {
      get { return GetProblemIntData("TestSamplesStart").Data; }
      set { GetProblemIntData("TestSamplesStart").Data = value; }
    }

    /// <summary>Gets or sets the problem injector's "TestSamplesEnd" value.</summary>
    public int TestSamplesEnd {
      get { return GetProblemIntData("TestSamplesEnd").Data; }
      set { GetProblemIntData("TestSamplesEnd").Data = value; }
    }

    /// <summary>
    /// Gets the analyzer model built from the engine's global scope via <see cref="CreateLRModel"/>.
    /// </summary>
    /// <exception cref="InvalidOperationException">The engine has not terminated.</exception>
    public virtual IAnalyzerModel Model {
      get {
        if (!engine.Terminated) throw new InvalidOperationException("The algorithm is still running. Wait until the algorithm is terminated to retrieve the result.");
        IScope bestModelScope = engine.GlobalScope;
        return CreateLRModel(bestModelScope);
      }
    }

    /// <summary>
    /// Creates a new instance whose <see cref="SequentialEngine.SequentialEngine"/> runs the
    /// operator graph returned by <see cref="CreateAlgorithm"/>.
    /// </summary>
    public LinearRegression() {
      engine = new SequentialEngine.SequentialEngine();
      // NOTE: virtual call from the constructor; overrides run before derived constructors.
      CombinedOperator algo = CreateAlgorithm();
      engine.OperatorGraph.AddOperator(algo);
      engine.OperatorGraph.InitialOperator = algo;
    }

    /// <summary>
    /// Builds the combined operator that makes up the algorithm. Its initial operator is a
    /// sequential processor whose sub-operator order is relied on by
    /// <see cref="ProblemInjector"/> and <see cref="GetVariableInjector"/>.
    /// </summary>
    protected virtual CombinedOperator CreateAlgorithm() {
      CombinedOperator algo = new CombinedOperator();
      SequentialProcessor seq = new SequentialProcessor();
      algo.Name = Name;
      seq.Name = Name;

      IOperator globalInjector = CreateGlobalInjector();

      HL3TreeEvaluatorInjector treeEvaluatorInjector = new HL3TreeEvaluatorInjector();

      // The regression is computed on the "actual" training range.
      LinearRegressionOperator lrOperator = new LinearRegressionOperator();
      lrOperator.GetVariableInfo("SamplesStart").ActualName = "ActualTrainingSamplesStart";
      lrOperator.GetVariableInfo("SamplesEnd").ActualName = "ActualTrainingSamplesEnd";

      seq.AddSubOperator(globalInjector);               // GlobalInjectorIndex
      seq.AddSubOperator(new RandomInjector());
      seq.AddSubOperator(CreateProblemInjector());      // ProblemInjectorIndex
      seq.AddSubOperator(treeEvaluatorInjector);
      seq.AddSubOperator(lrOperator);
      seq.AddSubOperator(CreatePostProcessingOperator());

      algo.OperatorGraph.InitialOperator = seq;
      algo.OperatorGraph.AddOperator(seq);

      return algo;
    }

    /// <summary>Creates the problem injector (default regression problem injector).</summary>
    protected virtual IOperator CreateProblemInjector() {
      return DefaultRegressionOperators.CreateProblemInjector();
    }

    /// <summary>
    /// Creates the injector for global variables "PunishmentFactor" (1000) and
    /// "TotalEvaluatedNodes" (0).
    /// </summary>
    protected virtual VariableInjector CreateGlobalInjector() {
      VariableInjector injector = new VariableInjector();
      injector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000)));
      injector.AddVariable(new HeuristicLab.Core.Variable("TotalEvaluatedNodes", new DoubleData(0)));

      return injector;
    }

    /// <summary>
    /// Creates the "Model Analyzer" operator. It evaluates the model on the training, validation
    /// and test ranges, extracts the input variable names, builds a predictor, and finally runs
    /// the operator returned by <see cref="CreateModelAnalyzerOperator"/>.
    /// </summary>
    protected virtual IOperator CreatePostProcessingOperator() {
      CombinedOperator op = new CombinedOperator();
      op.Name = "Model Analyzer";

      SequentialProcessor seq = new SequentialProcessor();
      HL3TreeEvaluatorInjector evaluatorInjector = new HL3TreeEvaluatorInjector();
      evaluatorInjector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000.0)));
      evaluatorInjector.GetVariableInfo("TreeEvaluator").ActualName = ModelAnalysisTreeEvaluatorName;

      #region simple evaluators
      SimpleEvaluator trainingEvaluator = CreateModelEvaluator("TrainingEvaluator", "TrainingSamplesStart", "TrainingSamplesEnd", "TrainingValues");
      SimpleEvaluator validationEvaluator = CreateModelEvaluator("ValidationEvaluator", "ValidationSamplesStart", "ValidationSamplesEnd", "ValidationValues");
      SimpleEvaluator testEvaluator = CreateModelEvaluator("TestEvaluator", "TestSamplesStart", "TestSamplesEnd", "TestValues");
      seq.AddSubOperator(evaluatorInjector);
      seq.AddSubOperator(trainingEvaluator);
      seq.AddSubOperator(validationEvaluator);
      seq.AddSubOperator(testEvaluator);
      #endregion

      #region variable impacts
      // calculate and set variable impacts
      VariableNamesExtractor namesExtractor = new VariableNamesExtractor();
      namesExtractor.GetVariableInfo("VariableNames").ActualName = "InputVariableNames";
      namesExtractor.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;

      PredictorBuilder predictorBuilder = new PredictorBuilder();
      predictorBuilder.GetVariableInfo("TreeEvaluator").ActualName = ModelAnalysisTreeEvaluatorName;
      predictorBuilder.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;

      seq.AddSubOperator(namesExtractor);
      seq.AddSubOperator(predictorBuilder);
      #endregion

      seq.AddSubOperator(CreateModelAnalyzerOperator());

      op.OperatorGraph.AddOperator(seq);
      op.OperatorGraph.InitialOperator = seq;
      return op;
    }

    // Creates a SimpleEvaluator named `name` that evaluates the model tree over the sample range
    // given by the variables `samplesStartName`/`samplesEndName` and stores the result under `valuesName`.
    private static SimpleEvaluator CreateModelEvaluator(string name, string samplesStartName, string samplesEndName, string valuesName) {
      SimpleEvaluator evaluator = new SimpleEvaluator();
      evaluator.Name = name;
      evaluator.GetVariableInfo("FunctionTree").ActualName = ModelVariableName;
      evaluator.GetVariableInfo("SamplesStart").ActualName = samplesStartName;
      evaluator.GetVariableInfo("SamplesEnd").ActualName = samplesEndName;
      evaluator.GetVariableInfo("Values").ActualName = valuesName;
      evaluator.GetVariableInfo("TreeEvaluator").ActualName = ModelAnalysisTreeEvaluatorName;
      return evaluator;
    }

    /// <summary>Creates the final model-analysis operator (default regression post-processing).</summary>
    protected virtual IOperator CreateModelAnalyzerOperator() {
      return DefaultRegressionOperators.CreatePostProcessingOperator();
    }

    /// <summary>
    /// Creates an <see cref="AnalyzerModel"/> populated from <paramref name="bestModelScope"/>
    /// by <c>DefaultRegressionOperators.PopulateAnalyzerModel</c>.
    /// </summary>
    protected virtual IAnalyzerModel CreateLRModel(IScope bestModelScope) {
      var model = new AnalyzerModel();
      DefaultRegressionOperators.PopulateAnalyzerModel(bestModelScope, model);
      return model;
    }

    /// <summary>
    /// Returns the initial operator of the combined operator that is the engine's initial operator
    /// (the main sequential processor built by <see cref="CreateAlgorithm"/>).
    /// </summary>
    protected virtual IOperator GetMainOperator() {
      CombinedOperator lr = (CombinedOperator)Engine.OperatorGraph.InitialOperator;
      return lr.OperatorGraph.InitialOperator;
    }

    /// <summary>Returns the global variable injector (first sub-operator of the main operator).</summary>
    protected virtual IOperator GetVariableInjector() {
      return GetMainOperator().SubOperators[GlobalInjectorIndex];
    }

    // Returns the combined operator that wraps the problem injector inside the main operator.
    private CombinedOperator GetProblemInjectorContainer() {
      return (CombinedOperator)GetMainOperator().SubOperators[ProblemInjectorIndex];
    }

    // Returns the IntData held by the problem injector variable with the given name.
    private IntData GetProblemIntData(string variableName) {
      return ProblemInjector.GetVariableValue<IntData>(variableName, null, false);
    }

    // Returns the list held by the problem injector's "AllowedFeatures" variable.
    private ItemList<IntData> GetAllowedFeatures() {
      return ProblemInjector.GetVariableValue<ItemList<IntData>>("AllowedFeatures", null, false);
    }

    /// <summary>Returns the engine's view.</summary>
    public override IView CreateView() {
      return engine.CreateView();
    }

    #region IEditable Members

    /// <summary>Returns the engine's editor.</summary>
    /// <remarks>
    /// NOTE(review): throws <see cref="InvalidCastException"/> if the engine is not a
    /// SequentialEngine (e.g. after restoring an engine of another type).
    /// </remarks>
    public virtual IEditor CreateEditor() {
      return ((SequentialEngine.SequentialEngine)engine).CreateEditor();
    }

    #endregion

    #region persistence
    /// <summary>
    /// Clones this algorithm; the engine is cloned with the shared <paramref name="clonedObjects"/> map.
    /// </summary>
    public override object Clone(IDictionary<Guid, object> clonedObjects) {
      LinearRegression clone = (LinearRegression)base.Clone(clonedObjects);
      clone.engine = (IEngine)Auxiliary.Clone(Engine, clonedObjects);
      return clone;
    }

    /// <summary>Persists this algorithm, storing the engine in a child node named "Engine".</summary>
    public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
      XmlNode node = base.GetXmlNode(name, document, persistedObjects);
      node.AppendChild(PersistenceManager.Persist("Engine", engine, document, persistedObjects));
      return node;
    }

    /// <summary>Restores this algorithm, reading the engine from the child node named "Engine".</summary>
    public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
      base.Populate(node, restoredObjects);
      engine = (IEngine)PersistenceManager.Restore(node.SelectSingleNode("Engine"), restoredObjects);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.