Free cookie consent management tool by TermsFeed Policy Generator

source: trunk/sources/HeuristicLab.LinearRegression/3.2/LinearRegression.cs @ 2636

Last change on this file since 2636 was 2569, checked in by gkronber, 15 years ago

Added calculation of tree complexity to LR models. #821

File size: 13.3 KB
RevLine 
[2154]1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2008 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Text;
26using HeuristicLab.Core;
27using System.Xml;
28using System.Diagnostics;
29using HeuristicLab.DataAnalysis;
30using HeuristicLab.Data;
31using HeuristicLab.Operators;
32using HeuristicLab.GP.StructureIdentification;
33using HeuristicLab.Modeling;
34using HeuristicLab.GP;
[2161]35using HeuristicLab.Random;
[2222]36using HeuristicLab.GP.Interfaces;
[2154]37
namespace HeuristicLab.LinearRegression {
  /// <summary>
  /// Algorithm that builds an operator graph which computes a linear regression model
  /// (via <see cref="LinearRegressionOperator"/>) on the training partition and then
  /// analyzes that model on the training, validation and test partitions.
  /// </summary>
  public class LinearRegression : ItemBase, IEditable, IAlgorithm {

    // Index of the problem-injector combined operator inside the main sequential processor.
    // It must match the order of the AddSubOperator calls in CreateAlgorithm():
    // 0 = global injector, 1 = random injector, 2 = problem injector.
    private const int ProblemInjectorIndex = 2;

    /// <summary>Name of the algorithm; also used to name the operators built in <see cref="CreateAlgorithm"/>.</summary>
    public virtual string Name { get { return "LinearRegression"; } }

    /// <summary>Description of the algorithm (currently a placeholder).</summary>
    public virtual string Description { get { return "TODO"; } }

    private IEngine engine;

    /// <summary>The engine that executes this algorithm's operator graph.</summary>
    public virtual IEngine Engine {
      get { return engine; }
    }

    /// <summary>Dataset held in the "Dataset" variable of the problem injector.</summary>
    public virtual Dataset Dataset {
      get { return ProblemInjector.GetVariableValue<Dataset>("Dataset", null, false); }
      set { ProblemInjector.GetVariable("Dataset").Value = value; }
    }

    /// <summary>Name of the target variable, held in the "TargetVariable" variable of the problem injector.</summary>
    public virtual string TargetVariable {
      get { return ProblemInjector.GetVariableValue<StringData>("TargetVariable", null, false).Data; }
      set { ProblemInjector.GetVariableValue<StringData>("TargetVariable", null, false).Data = value; }
    }

    /// <summary>
    /// The operator that injects the problem variables: the first sub-operator of the initial
    /// operator of the combined operator at position 2 of the main sequence.
    /// Setting it replaces that first sub-operator.
    /// </summary>
    public virtual IOperator ProblemInjector {
      get {
        return GetProblemInjectorContainer().OperatorGraph.InitialOperator.SubOperators[0];
      }
      set {
        IOperator initialOperator = GetProblemInjectorContainer().OperatorGraph.InitialOperator;
        initialOperator.RemoveSubOperator(0);
        initialOperator.AddSubOperator(value, 0);
      }
    }

    /// <summary>
    /// Names of the variables that may be used as inputs, held in the problem injector's
    /// "AllowedFeatures" list. The getter returns a lazy view of that list. Setting the
    /// property replaces the list's contents with the given names.
    /// </summary>
    /// <exception cref="ArgumentNullException">The assigned value is null.</exception>
    public IEnumerable<string> AllowedVariables {
      get {
        ItemList<StringData> allowedVariables = GetAllowedFeatures();
        return allowedVariables.Select(x => x.Data);
      }
      set {
        if (value == null) throw new ArgumentNullException("value");
        // Materialize first: value may be a lazy view of the very list that is cleared below
        // (for example, the result of this property's getter).
        List<string> names = new List<string>(value);
        ItemList<StringData> allowedVariables = GetAllowedFeatures();
        allowedVariables.Clear();
        foreach (string variableName in names) allowedVariables.Add(new StringData(variableName));
      }
    }

    /// <summary>Start index of the training partition ("TrainingSamplesStart").</summary>
    public int TrainingSamplesStart {
      get { return GetIntVariable("TrainingSamplesStart"); }
      set { SetIntVariable("TrainingSamplesStart", value); }
    }

    /// <summary>End index of the training partition ("TrainingSamplesEnd").</summary>
    public int TrainingSamplesEnd {
      get { return GetIntVariable("TrainingSamplesEnd"); }
      set { SetIntVariable("TrainingSamplesEnd", value); }
    }

    /// <summary>Start index of the validation partition ("ValidationSamplesStart").</summary>
    public int ValidationSamplesStart {
      get { return GetIntVariable("ValidationSamplesStart"); }
      set { SetIntVariable("ValidationSamplesStart", value); }
    }

    /// <summary>End index of the validation partition ("ValidationSamplesEnd").</summary>
    public int ValidationSamplesEnd {
      get { return GetIntVariable("ValidationSamplesEnd"); }
      set { SetIntVariable("ValidationSamplesEnd", value); }
    }

    /// <summary>Start index of the test partition ("TestSamplesStart").</summary>
    public int TestSamplesStart {
      get { return GetIntVariable("TestSamplesStart"); }
      set { SetIntVariable("TestSamplesStart", value); }
    }

    /// <summary>End index of the test partition ("TestSamplesEnd").</summary>
    public int TestSamplesEnd {
      get { return GetIntVariable("TestSamplesEnd"); }
      set { SetIntVariable("TestSamplesEnd", value); }
    }

    /// <summary>
    /// Builds an analyzer model from the engine's global scope once the run has finished.
    /// </summary>
    /// <exception cref="InvalidOperationException">The engine has not terminated yet.</exception>
    public virtual IAnalyzerModel Model {
      get {
        if (!engine.Terminated) throw new InvalidOperationException("The algorithm is still running. Wait until the algorithm is terminated to retrieve the result.");
        IScope bestModelScope = engine.GlobalScope;
        return CreateLRModel(bestModelScope);
      }
    }

    /// <summary>
    /// Creates a sequential engine whose initial operator is the graph returned by
    /// <see cref="CreateAlgorithm"/>.
    /// </summary>
    // NOTE(review): CreateAlgorithm() and Name are virtual and are called from the constructor,
    // so overrides run before a subclass constructor body has executed.
    public LinearRegression() {
      engine = new SequentialEngine.SequentialEngine();
      CombinedOperator algo = CreateAlgorithm();
      engine.OperatorGraph.AddOperator(algo);
      engine.OperatorGraph.InitialOperator = algo;
    }

    /// <summary>
    /// Builds the main sequence, in this order: global injector, random injector,
    /// problem injector, tree evaluator injector, linear regression operator,
    /// and post-processing.
    /// </summary>
    protected virtual CombinedOperator CreateAlgorithm() {
      CombinedOperator algo = new CombinedOperator();
      SequentialProcessor seq = new SequentialProcessor();
      algo.Name = Name;
      seq.Name = Name;

      IOperator globalInjector = CreateGlobalInjector();

      HL3TreeEvaluatorInjector treeEvaluatorInjector = new HL3TreeEvaluatorInjector();

      // The regression is fitted on the "actual" training range.
      // Presumably that range is provided by the problem injector — TODO confirm.
      LinearRegressionOperator lrOperator = new LinearRegressionOperator();
      lrOperator.GetVariableInfo("SamplesStart").ActualName = "ActualTrainingSamplesStart";
      lrOperator.GetVariableInfo("SamplesEnd").ActualName = "ActualTrainingSamplesEnd";

      // The order matters: ProblemInjector and GetVariableInjector() look operators up by index.
      seq.AddSubOperator(globalInjector);
      seq.AddSubOperator(new RandomInjector());
      seq.AddSubOperator(CreateProblemInjector());
      seq.AddSubOperator(treeEvaluatorInjector);
      seq.AddSubOperator(lrOperator);
      seq.AddSubOperator(CreatePostProcessingOperator());

      algo.OperatorGraph.InitialOperator = seq;
      algo.OperatorGraph.AddOperator(seq);

      return algo;
    }

    /// <summary>Creates the problem injector (the default regression problem injector).</summary>
    protected virtual IOperator CreateProblemInjector() {
      return DefaultRegressionOperators.CreateProblemInjector();
    }

    /// <summary>
    /// Creates the injector for global variables: PunishmentFactor (1000),
    /// TotalEvaluatedNodes (0) and MaxNumberOfTrainingSamples (4000).
    /// </summary>
    protected virtual VariableInjector CreateGlobalInjector() {
      VariableInjector injector = new VariableInjector();
      injector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000)));
      injector.AddVariable(new HeuristicLab.Core.Variable("TotalEvaluatedNodes", new DoubleData(0)));
      injector.AddVariable(new HeuristicLab.Core.Variable("MaxNumberOfTrainingSamples", new IntData(4000)));

      return injector;
    }

    /// <summary>
    /// Creates the "Model Analyzer" combined operator. It evaluates the fitted
    /// "LinearRegressionModel" on the training, validation and test partitions, computes
    /// variable quality impacts on the training partition, and then runs
    /// <see cref="CreateModelAnalyzerOperator"/>.
    /// </summary>
    protected virtual IOperator CreatePostProcessingOperator() {
      CombinedOperator op = new CombinedOperator();
      op.Name = "Model Analyzer";

      SequentialProcessor seq = new SequentialProcessor();
      HL3TreeEvaluatorInjector evaluatorInjector = new HL3TreeEvaluatorInjector();
      evaluatorInjector.AddVariable(new HeuristicLab.Core.Variable("PunishmentFactor", new DoubleData(1000.0)));
      evaluatorInjector.GetVariableInfo("TreeEvaluator").ActualName = "ModelAnalysisTreeEvaluator";

      #region simple evaluators
      seq.AddSubOperator(evaluatorInjector);
      seq.AddSubOperator(CreateModelEvaluator("TrainingEvaluator", "TrainingSamplesStart", "TrainingSamplesEnd", "TrainingValues"));
      seq.AddSubOperator(CreateModelEvaluator("ValidationEvaluator", "ValidationSamplesStart", "ValidationSamplesEnd", "ValidationValues"));
      seq.AddSubOperator(CreateModelEvaluator("TestEvaluator", "TestSamplesStart", "TestSamplesEnd", "TestValues"));
      #endregion

      #region variable impacts
      // calculate and set variable impacts
      VariableNamesExtractor namesExtractor = new VariableNamesExtractor();
      namesExtractor.GetVariableInfo("VariableNames").ActualName = "InputVariableNames";
      namesExtractor.GetVariableInfo("FunctionTree").ActualName = "LinearRegressionModel";

      PredictorBuilder predictorBuilder = new PredictorBuilder();
      predictorBuilder.GetVariableInfo("TreeEvaluator").ActualName = "ModelAnalysisTreeEvaluator";
      predictorBuilder.GetVariableInfo("FunctionTree").ActualName = "LinearRegressionModel";

      seq.AddSubOperator(namesExtractor);
      seq.AddSubOperator(predictorBuilder);
      VariableQualityImpactCalculator qualityImpactCalculator = new VariableQualityImpactCalculator();
      qualityImpactCalculator.GetVariableInfo("SamplesStart").ActualName = "TrainingSamplesStart";
      qualityImpactCalculator.GetVariableInfo("SamplesEnd").ActualName = "TrainingSamplesEnd";

      seq.AddSubOperator(qualityImpactCalculator);
      #endregion

      seq.AddSubOperator(CreateModelAnalyzerOperator());

      op.OperatorGraph.AddOperator(seq);
      op.OperatorGraph.InitialOperator = seq;
      return op;
    }

    // Creates a SimpleEvaluator that evaluates "LinearRegressionModel" with the
    // "ModelAnalysisTreeEvaluator" on the partition [samplesStart, samplesEnd],
    // writing its output to the variable named by values.
    private static SimpleEvaluator CreateModelEvaluator(string name, string samplesStart, string samplesEnd, string values) {
      SimpleEvaluator evaluator = new SimpleEvaluator();
      evaluator.Name = name;
      evaluator.GetVariableInfo("FunctionTree").ActualName = "LinearRegressionModel";
      evaluator.GetVariableInfo("SamplesStart").ActualName = samplesStart;
      evaluator.GetVariableInfo("SamplesEnd").ActualName = samplesEnd;
      evaluator.GetVariableInfo("Values").ActualName = values;
      evaluator.GetVariableInfo("TreeEvaluator").ActualName = "ModelAnalysisTreeEvaluator";
      return evaluator;
    }

    /// <summary>Creates the final model-analysis operator (the default regression post-processing operator).</summary>
    protected virtual IOperator CreateModelAnalyzerOperator() {
      return DefaultRegressionOperators.CreatePostProcessingOperator();
    }

    /// <summary>
    /// Builds an analyzer model from the scope that holds "LinearRegressionModel".
    /// It records tree size, height, complexity and average node complexity as metadata,
    /// delegates to <see cref="CreateSpecificLRModel"/>, and then adds each variable's
    /// quality impact and registers it as an input variable.
    /// </summary>
    protected virtual IAnalyzerModel CreateLRModel(IScope bestModelScope) {
      var model = new AnalyzerModel();
      IGeneticProgrammingModel gpModel = bestModelScope.GetVariableValue<IGeneticProgrammingModel>("LinearRegressionModel", false);
      model.SetMetaData("TreeSize", gpModel.Size);
      model.SetMetaData("TreeHeight", gpModel.Height);
      double treeComplexity = TreeComplexityEvaluator.Calculate(gpModel.FunctionTree);
      model.SetMetaData("TreeComplexity", treeComplexity);
      // treeComplexity is a double, so this is floating-point division.
      model.SetMetaData("AverageNodeComplexity", treeComplexity / gpModel.Size);

      CreateSpecificLRModel(bestModelScope, model);
      #region variable impacts
      // Each row is read as [StringData variableName, DoubleData impact].
      ItemList qualityImpacts = bestModelScope.GetVariableValue<ItemList>(ModelingResult.VariableQualityImpact.ToString(), false);
      foreach (ItemList row in qualityImpacts) {
        string variableName = ((StringData)row[0]).Data;
        double impact = ((DoubleData)row[1]).Data;
        model.SetVariableResult(ModelingResult.VariableQualityImpact, variableName, impact);
        model.AddInputVariable(variableName);
      }
      #endregion
      return model;
    }

    /// <summary>Hook for subclasses to add model data; by default populates it with the default regression results.</summary>
    protected virtual void CreateSpecificLRModel(IScope bestModelScope, IAnalyzerModel model) {
      DefaultRegressionOperators.PopulateAnalyzerModel(bestModelScope, model);
    }

    /// <summary>Returns the main sequential processor (the initial operator of the algorithm's combined operator).</summary>
    protected virtual IOperator GetMainOperator() {
      CombinedOperator lr = (CombinedOperator)Engine.OperatorGraph.InitialOperator;
      return lr.OperatorGraph.InitialOperator;
    }

    /// <summary>Returns the global variable injector (first sub-operator of the main sequence).</summary>
    protected virtual IOperator GetVariableInjector() {
      return GetMainOperator().SubOperators[0];
    }

    // Returns the combined operator that wraps the problem injector in the main sequence.
    private CombinedOperator GetProblemInjectorContainer() {
      return (CombinedOperator)GetMainOperator().SubOperators[ProblemInjectorIndex];
    }

    // Returns the problem injector's "AllowedFeatures" list.
    private ItemList<StringData> GetAllowedFeatures() {
      return ProblemInjector.GetVariableValue<ItemList<StringData>>("AllowedFeatures", null, false);
    }

    // Reads the IntData variable with the given name from the problem injector.
    private int GetIntVariable(string name) {
      return ProblemInjector.GetVariableValue<IntData>(name, null, false).Data;
    }

    // Writes the IntData variable with the given name in the problem injector.
    private void SetIntVariable(string name, int value) {
      ProblemInjector.GetVariableValue<IntData>(name, null, false).Data = value;
    }

    /// <summary>Returns the engine's view.</summary>
    public override IView CreateView() {
      return engine.CreateView();
    }

    #region IEditable Members

    /// <summary>Returns the engine's editor.</summary>
    // NOTE(review): this cast assumes the engine is a SequentialEngine (as created in the constructor);
    // a different restored engine type would throw InvalidCastException here.
    public virtual IEditor CreateEditor() {
      return ((SequentialEngine.SequentialEngine)engine).CreateEditor();
    }

    #endregion

    #region persistence
    /// <summary>Clones the algorithm, including a deep clone of its engine.</summary>
    public override object Clone(IDictionary<Guid, object> clonedObjects) {
      LinearRegression clone = (LinearRegression)base.Clone(clonedObjects);
      clone.engine = (IEngine)Auxiliary.Clone(Engine, clonedObjects);
      return clone;
    }

    /// <summary>Persists the algorithm with its engine stored under the "Engine" child node.</summary>
    public override XmlNode GetXmlNode(string name, XmlDocument document, IDictionary<Guid, IStorable> persistedObjects) {
      XmlNode node = base.GetXmlNode(name, document, persistedObjects);
      node.AppendChild(PersistenceManager.Persist("Engine", engine, document, persistedObjects));
      return node;
    }

    /// <summary>Restores the algorithm and its engine from the "Engine" child node.</summary>
    public override void Populate(XmlNode node, IDictionary<Guid, IStorable> restoredObjects) {
      base.Populate(node, restoredObjects);
      engine = (IEngine)PersistenceManager.Restore(node.SelectSingleNode("Engine"), restoredObjects);
    }
    #endregion
  }
}
Note: See TracBrowser for help on using the repository browser.