source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelSurrogate.cs @ 17154

Last change on this file since 17154 was 17154, checked in by gkronber, 2 months ago

#2952: merged relevant revisions from branch to trunk

Merged revision(s) 17045-17153 from branches/2952_RF-ModelStorage/HeuristicLab.Algorithms.DataAnalysis:
#2952: Intermediate commit of refactoring RF models that is not yet finished.

........
#2952: Corrected evaluation in RF models.

........
#2952: Finished implementation of different RF models.

........
#2952 Fixed triggering model recalculation when cloning.
........
#2952: merged r17137 from trunk to branch
........
#2952: re-added backwards compatibility code for very old versions of GBT and RF
........
#2952: hide parameter in backwards compatibility hook
........

17045-17153

File size: 6.7 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) 2002-2019 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Problems.DataAnalysis;
29
30namespace HeuristicLab.Algorithms.DataAnalysis {
31  [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
32  [Item("RandomForestModelSurrogate", "Represents a random forest for regression and classification.")]
33  public sealed class RandomForestModelSurrogate : ClassificationModel, IRandomForestModel {
34
35    #region parameters for recalculation of the model
36    [Storable]
37    private int seed;
38    [Storable]
39    private IDataAnalysisProblemData originalTrainingData;
40    [Storable]
41    private double[] classValues;
42    [Storable]
43    private int nTrees;
44    [Storable]
45    private double r;
46    [Storable]
47    private double m;
48    #endregion
49
50    // don't store the actual model!
51    // the actual model is only recalculated when necessary
52    private readonly Lazy<IRandomForestModel> actualModel;
53    private IRandomForestModel ActualModel {
54      get { return actualModel.Value; }
55    }
56
57    public int NumberOfTrees => ActualModel.NumberOfTrees;
58    public override IEnumerable<string> VariablesUsedForPrediction {
59      get { return ActualModel.VariablesUsedForPrediction; }
60    }
61
62    public RandomForestModelSurrogate(string targetVariable, IDataAnalysisProblemData originalTrainingData,
63      int seed, int nTrees, double r, double m, double[] classValues = null)
64      : base(targetVariable) {
65      this.name = ItemName;
66      this.description = ItemDescription;
67
68      // data which is necessary for recalculation of the model
69      this.seed = seed;
70      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
71      this.classValues = classValues;
72      this.nTrees = nTrees;
73      this.r = r;
74      this.m = m;
75
76      actualModel = new Lazy<IRandomForestModel>(() => RecalculateModel());
77    }
78
79    // wrap an actual model in a surrograte
80    public RandomForestModelSurrogate(IRandomForestModel model, string targetVariable, IDataAnalysisProblemData originalTrainingData,
81      int seed, int nTrees, double r, double m, double[] classValues = null) : this(targetVariable, originalTrainingData, seed, nTrees, r, m, classValues) {
82      actualModel = new Lazy<IRandomForestModel>(() => model);
83    }
84
85    [StorableConstructor]
86    private RandomForestModelSurrogate(StorableConstructorFlag _) : base(_) {
87      actualModel = new Lazy<IRandomForestModel>(() => RecalculateModel());
88    }
89
90    private RandomForestModelSurrogate(RandomForestModelSurrogate original, Cloner cloner) : base(original, cloner) {
91      IRandomForestModel clonedModel = null;
92      if (original.actualModel.IsValueCreated) clonedModel = cloner.Clone(original.ActualModel);
93      actualModel = new Lazy<IRandomForestModel>(CreateLazyInitFunc(clonedModel)); // only capture clonedModel in the closure
94
95      // clone data which is necessary to rebuild the model
96      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
97      this.seed = original.seed;
98      this.classValues = original.classValues;
99      this.nTrees = original.nTrees;
100      this.r = original.r;
101      this.m = original.m;
102    }
103
104    private Func<IRandomForestModel> CreateLazyInitFunc(IRandomForestModel clonedModel) {
105      return () => {
106        return clonedModel ?? RecalculateModel();
107      };
108    }
109
110    public override IDeepCloneable Clone(Cloner cloner) {
111      return new RandomForestModelSurrogate(this, cloner);
112    }
113
114    private IRandomForestModel RecalculateModel() {
115      IRandomForestModel randomForestModel = null;
116
117      double rmsError, oobRmsError, relClassError, oobRelClassError;
118      var classificationProblemData = originalTrainingData as IClassificationProblemData;
119
120      if (originalTrainingData is IRegressionProblemData regressionProblemData) {
121        randomForestModel = RandomForestRegression.CreateRandomForestRegressionModel(regressionProblemData,
122                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
123                                              out relClassError, out oobRelClassError);
124      } else if (classificationProblemData != null) {
125        randomForestModel = RandomForestClassification.CreateRandomForestClassificationModel(classificationProblemData,
126                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
127                                              out relClassError, out oobRelClassError);
128      }
129      return randomForestModel;
130    }
131
132    //RegressionModel methods
133    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
134      return ActualModel.IsProblemDataCompatible(problemData, out errorMessage);
135    }
136    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
137      return ActualModel.GetEstimatedValues(dataset, rows);
138    }
139    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
140      return ActualModel.GetEstimatedVariances(dataset, rows);
141    }
142    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
143      return new RandomForestRegressionSolution(this, (IRegressionProblemData)problemData.Clone());
144    }
145
146    //ClassificationModel methods
147    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
148      return ActualModel.GetEstimatedClassValues(dataset, rows);
149    }
150    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
151      return new RandomForestClassificationSolution(this, (IClassificationProblemData)problemData.Clone());
152    }
153
154    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
155      return ActualModel.ExtractTree(treeIdx);
156    }
157  }
158}
Note: See TracBrowser for help on using the repository browser.