source: trunk/HeuristicLab.Algorithms.DataAnalysis/3.4/RandomForest/RandomForestModelSurrogate.cs @ 17873

Last change on this file since 17873 was 17873, checked in by mkommend, 7 months ago

#3113: Added overload of IsProblemDataCompatible to SurrogateRFModel to handle the compatibility check correctly.

File size: 6.9 KB
Line 
1#region License Information
2/* HeuristicLab
3 * Copyright (C) Heuristic and Evolutionary Algorithms Laboratory (HEAL)
4 *
5 * This file is part of HeuristicLab.
6 *
7 * HeuristicLab is free software: you can redistribute it and/or modify
8 * it under the terms of the GNU General Public License as published by
9 * the Free Software Foundation, either version 3 of the License, or
10 * (at your option) any later version.
11 *
12 * HeuristicLab is distributed in the hope that it will be useful,
13 * but WITHOUT ANY WARRANTY; without even the implied warranty of
14 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 * GNU General Public License for more details.
16 *
17 * You should have received a copy of the GNU General Public License
18 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
19 */
20#endregion
21
22using System;
23using System.Collections.Generic;
24using HEAL.Attic;
25using HeuristicLab.Common;
26using HeuristicLab.Core;
27using HeuristicLab.Encodings.SymbolicExpressionTreeEncoding;
28using HeuristicLab.Problems.DataAnalysis;
29
namespace HeuristicLab.Algorithms.DataAnalysis {
  /// <summary>
  /// Stand-in for a random forest model that keeps only the inputs needed to rebuild the
  /// forest (training data, seed and training parameters) as storable state. The actual
  /// forest is held in a non-storable field and is rebuilt on first access when it is
  /// not already available. All model operations are forwarded to that actual model.
  /// </summary>
  [StorableType("A4F688CD-1F42-4103-8449-7DE52AEF6C69")]
  [Item("RandomForestModelSurrogate", "Represents a random forest for regression and classification.")]
  public sealed class RandomForestModelSurrogate : ClassificationModel, IRandomForestModel {

    #region parameters for recalculation of the model
    [Storable]
    private readonly int seed;
    [Storable]
    private readonly IDataAnalysisProblemData originalTrainingData;
    // stored alongside the other parameters; not read by RecalculateModel in this class
    [Storable]
    private readonly double[] classValues;
    [Storable]
    private readonly int nTrees;
    [Storable]
    private readonly double r;
    [Storable]
    private readonly double m;
    #endregion


    // don't store the actual model!
    // the actual model is only recalculated when necessary
    private IRandomForestModel fullModel = null;
    private readonly Lazy<IRandomForestModel> actualModel;

    /// <summary>
    /// The wrapped random forest model; reading it evaluates the lazy initializer,
    /// which rebuilds the forest if no full model is present yet.
    /// </summary>
    private IRandomForestModel ActualModel {
      get { return actualModel.Value; }
    }

    /// <summary>Number of trees as reported by the actual model.</summary>
    public int NumberOfTrees => ActualModel.NumberOfTrees;

    /// <summary>Variables used for prediction as reported by the actual model.</summary>
    public override IEnumerable<string> VariablesUsedForPrediction {
      get { return ActualModel.VariablesUsedForPrediction; }
    }

    /// <summary>
    /// Creates a surrogate without an actual model; the forest is built on first use.
    /// </summary>
    /// <param name="targetVariable">Name of the target variable, passed to the base class.</param>
    /// <param name="originalTrainingData">Training data used to rebuild the model; a clone of it is kept.</param>
    /// <param name="seed">Seed passed to the model creation routine.</param>
    /// <param name="nTrees">Number of trees passed to the model creation routine.</param>
    /// <param name="r">Training parameter passed through to the model creation routine.</param>
    /// <param name="m">Training parameter passed through to the model creation routine.</param>
    /// <param name="classValues">Optional class values; stored as given (no copy is made).</param>
    /// <exception cref="ArgumentNullException"><paramref name="originalTrainingData"/> is null.</exception>
    public RandomForestModelSurrogate(string targetVariable, IDataAnalysisProblemData originalTrainingData,
      int seed, int nTrees, double r, double m, double[] classValues = null)
      : base(targetVariable) {
      // fail with a descriptive exception instead of a NullReferenceException on Clone() below
      if (originalTrainingData == null) throw new ArgumentNullException(nameof(originalTrainingData));

      this.name = ItemName;
      this.description = ItemDescription;

      // data which is necessary for recalculation of the model
      this.seed = seed;
      this.originalTrainingData = (IDataAnalysisProblemData)originalTrainingData.Clone();
      this.classValues = classValues;
      this.nTrees = nTrees;
      this.r = r;
      this.m = m;

      actualModel = CreateLazyInitFunc();
    }

    /// <summary>
    /// Wraps an already available model in a surrogate, so no recalculation is needed
    /// as long as <paramref name="model"/> is not null.
    /// </summary>
    /// <param name="model">The actual model to wrap.</param>
    /// <param name="targetVariable">Name of the target variable, passed to the base class.</param>
    /// <param name="originalTrainingData">Training data used to rebuild the model; a clone of it is kept.</param>
    /// <param name="seed">Seed passed to the model creation routine.</param>
    /// <param name="nTrees">Number of trees passed to the model creation routine.</param>
    /// <param name="r">Training parameter passed through to the model creation routine.</param>
    /// <param name="m">Training parameter passed through to the model creation routine.</param>
    /// <param name="classValues">Optional class values; stored as given (no copy is made).</param>
    public RandomForestModelSurrogate(IRandomForestModel model, string targetVariable, IDataAnalysisProblemData originalTrainingData,
      int seed, int nTrees, double r, double m, double[] classValues = null)
      : this(targetVariable, originalTrainingData, seed, nTrees, r, m, classValues) {
      fullModel = model;
    }

    /// <summary>
    /// Storable constructor; only sets up the lazy initializer because the actual
    /// model is not a storable field.
    /// </summary>
    [StorableConstructor]
    private RandomForestModelSurrogate(StorableConstructorFlag _) : base(_) {
      actualModel = CreateLazyInitFunc();
    }

    /// <summary>
    /// Cloning constructor: copies the recalculation parameters and clones the full
    /// model only when the original has already created one.
    /// </summary>
    private RandomForestModelSurrogate(RandomForestModelSurrogate original, Cloner cloner) : base(original, cloner) {
      // clone data which is necessary to rebuild the model
      this.originalTrainingData = cloner.Clone(original.originalTrainingData);
      this.seed = original.seed;
      this.classValues = original.classValues; // array reference is shared with the original
      this.nTrees = original.nTrees;
      this.r = original.r;
      this.m = original.m;

      // clone full model if it has already been created
      if (original.fullModel != null) this.fullModel = cloner.Clone(original.fullModel);
      actualModel = CreateLazyInitFunc();
    }

    /// <summary>Creates a clone of this surrogate via the cloning constructor.</summary>
    public override IDeepCloneable Clone(Cloner cloner) {
      return new RandomForestModelSurrogate(this, cloner);
    }

    /// <summary>
    /// Builds the lazy wrapper that returns the existing full model or, if there is
    /// none, recalculates it and remembers the result in <c>fullModel</c>.
    /// </summary>
    private Lazy<IRandomForestModel> CreateLazyInitFunc() {
      return new Lazy<IRandomForestModel>(() => fullModel ?? (fullModel = RecalculateModel()));
    }

    /// <summary>
    /// Rebuilds the random forest from the stored training data, using the regression
    /// or the classification creation routine depending on the type of that data.
    /// </summary>
    /// <returns>The newly created random forest model; never null.</returns>
    /// <exception cref="InvalidOperationException">
    /// The stored training data is neither regression nor classification problem data.
    /// </exception>
    private IRandomForestModel RecalculateModel() {
      // error measures are mandatory out arguments of the creation routines; they are not used here
      double rmsError, oobRmsError, relClassError, oobRelClassError;

      // regression is checked first, as in the original implementation
      if (originalTrainingData is IRegressionProblemData regressionProblemData) {
        return RandomForestRegression.CreateRandomForestRegressionModel(regressionProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
      }

      if (originalTrainingData is IClassificationProblemData classificationProblemData) {
        return RandomForestClassification.CreateRandomForestClassificationModel(classificationProblemData,
                                              nTrees, r, m, seed, out rmsError, out oobRmsError,
                                              out relClassError, out oobRelClassError);
      }

      // previously null was returned here, which surfaced later as a NullReferenceException
      // in whichever member happened to access the actual model first
      var dataTypeName = originalTrainingData == null ? "null" : originalTrainingData.GetType().Name;
      throw new InvalidOperationException(
        "Cannot recalculate the random forest model: training data of type '" + dataTypeName +
        "' is neither regression nor classification problem data.");
    }

    /// <summary>Forwards the compatibility check to the actual model.</summary>
    public override bool IsProblemDataCompatible(IDataAnalysisProblemData problemData, out string errorMessage) {
      return ActualModel.IsProblemDataCompatible(problemData, out errorMessage);
    }

    //RegressionModel methods

    /// <summary>Forwards the regression-specific compatibility check to the actual model.</summary>
    public bool IsProblemDataCompatible(IRegressionProblemData problemData, out string errorMessage) {
      return ActualModel.IsProblemDataCompatible(problemData, out errorMessage);
    }

    /// <summary>Returns the estimated values of the actual model for the given rows.</summary>
    public IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
      return ActualModel.GetEstimatedValues(dataset, rows);
    }

    /// <summary>Returns the estimated variances of the actual model for the given rows.</summary>
    public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
      return ActualModel.GetEstimatedVariances(dataset, rows);
    }

    /// <summary>
    /// Creates a regression solution for this surrogate and a clone of <paramref name="problemData"/>.
    /// </summary>
    public IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
      return new RandomForestRegressionSolution(this, (IRegressionProblemData)problemData.Clone());
    }

    //ClassificationModel methods

    /// <summary>Returns the estimated class values of the actual model for the given rows.</summary>
    public override IEnumerable<double> GetEstimatedClassValues(IDataset dataset, IEnumerable<int> rows) {
      return ActualModel.GetEstimatedClassValues(dataset, rows);
    }

    /// <summary>
    /// Creates a classification solution for this surrogate and a clone of <paramref name="problemData"/>.
    /// </summary>
    public override IClassificationSolution CreateClassificationSolution(IClassificationProblemData problemData) {
      return new RandomForestClassificationSolution(this, (IClassificationProblemData)problemData.Clone());
    }

    /// <summary>Extracts the tree with index <paramref name="treeIdx"/> from the actual model.</summary>
    public ISymbolicExpressionTree ExtractTree(int treeIdx) {
      return ActualModel.ExtractTree(treeIdx);
    }
  }
}
Note: See TracBrowser for help on using the repository browser.