Free cookie consent management tool by TermsFeed Policy Generator

source: branches/M5Regression/HeuristicLab.Algorithms.DataAnalysis/3.4/M5Regression/MetaModels/M5NodeModel.cs @ 15470

Last change on this file since 15470 was 15470, checked in by bwerth, 6 years ago

#2847 worked on M5Regression

File size: 11.6 KB
Line 
#region License Information
/* HeuristicLab
 * Copyright (C) 2002-2017 Heuristic and Evolutionary Algorithms Laboratory (HEAL)
 *
 * This file is part of HeuristicLab.
 *
 * HeuristicLab is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 *
 * HeuristicLab is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with HeuristicLab. If not, see <http://www.gnu.org/licenses/>.
 */
#endregion
21
22using System;
23using System.Collections.Generic;
24using System.Linq;
25using System.Threading;
26using HeuristicLab.Common;
27using HeuristicLab.Core;
28using HeuristicLab.Data;
29using HeuristicLab.Persistence.Default.CompositeSerializers.Storable;
30using HeuristicLab.Problems.DataAnalysis;
31
32namespace HeuristicLab.Algorithms.DataAnalysis {
33  [StorableClass]
34  internal class M5NodeModel : RegressionModel {
35    #region Properties
36    [Storable]
37    internal bool IsLeaf { get; private set; }
38    [Storable]
39    internal IRegressionModel NodeModel { get; private set; }
40    [Storable]
41    internal string SplitAttr { get; private set; }
42    [Storable]
43    internal double SplitValue { get; private set; }
44    [Storable]
45    internal M5NodeModel Left { get; private set; }
46    [Storable]
47    internal M5NodeModel Right { get; private set; }
48    [Storable]
49    internal M5NodeModel Parent { get; set; }
50    [Storable]
51    internal int NumSamples { get; private set; }
52    [Storable]
53    internal int NumParam { get; set; }
54    [Storable]
55    internal int NodeModelParams { get; set; }
56    [Storable]
57    private IReadOnlyList<string> Variables { get; set; }
58    #endregion
59
60    #region HLConstructors
61    [StorableConstructor]
62    protected M5NodeModel(bool deserializing) : base(deserializing) { }
63    protected M5NodeModel(M5NodeModel original, Cloner cloner) : base(original, cloner) {
64      IsLeaf = original.IsLeaf;
65      NodeModel = cloner.Clone(original.NodeModel);
66      SplitValue = original.SplitValue;
67      SplitAttr = original.SplitAttr;
68      Left = cloner.Clone(original.Left);
69      Right = cloner.Clone(original.Right);
70      Parent = cloner.Clone(original.Parent);
71      NumParam = original.NumParam;
72      NumSamples = original.NumSamples;
73      Variables = original.Variables != null ? original.Variables.ToList() : null;
74    }
75    protected M5NodeModel(string targetAttr) : base(targetAttr) { }
76    protected M5NodeModel(M5NodeModel parent) : base(parent.TargetVariable) {
77      Parent = parent;
78    }
79    public override IDeepCloneable Clone(Cloner cloner) {
80      return new M5NodeModel(this, cloner);
81    }
82    public static M5NodeModel CreateNode(string targetAttr, M5CreationParameters m5CreationParams) {
83      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(targetAttr) : new M5NodeModel(targetAttr);
84    }
85    private static M5NodeModel CreateNode(M5NodeModel parent, M5CreationParameters m5CreationParams) {
86      return m5CreationParams.LeafType is ILeafType<IConfidenceRegressionModel> ? new ConfidenceM5NodeModel(parent) : new M5NodeModel(parent);
87    }
88    #endregion
89
90    #region RegressionModel
91    public override IEnumerable<string> VariablesUsedForPrediction {
92      get { return Variables; }
93    }
94    public override IEnumerable<double> GetEstimatedValues(IDataset dataset, IEnumerable<int> rows) {
95      if (!IsLeaf) return rows.Select(row => GetEstimatedValue(dataset, row));
96      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
97      return NodeModel.GetEstimatedValues(dataset, rows);
98    }
99    public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
100      return new RegressionSolution(this, problemData);
101    }
102    #endregion
103
104    internal void Split(IReadOnlyList<int> rows, M5CreationParameters m5CreationParams, double globalStdDev) {
105      Variables = m5CreationParams.AllowedInputVariables.ToArray();
106      NumSamples = rows.Count;
107      Right = null;
108      Left = null;
109      NodeModel = null;
110      SplitAttr = null;
111      SplitValue = double.NaN;
112      string attr;
113      double splitValue;
114      //IsLeaf = m5CreationParams.Data.GetDoubleValues(TargetVariable, rows).StandardDeviation() < globalStdDev * DevFraction;
115      //if (IsLeaf) return;
116      IsLeaf = !m5CreationParams.Split.Split(new RegressionProblemData(ReduceDataset(m5CreationParams.Data, rows), Variables, TargetVariable), m5CreationParams.MinLeafSize, out attr, out splitValue);
117      if (IsLeaf) return;
118
119      //split Dataset
120      IReadOnlyList<int> leftRows, rightRows;
121      SplitRows(rows, m5CreationParams.Data, attr, splitValue, out leftRows, out rightRows);
122
123      if (leftRows.Count < m5CreationParams.MinLeafSize || rightRows.Count < m5CreationParams.MinLeafSize) {
124        IsLeaf = true;
125        return;
126      }
127      SplitAttr = attr;
128      SplitValue = splitValue;
129
130      //create subtrees
131      Left = CreateNode(this, m5CreationParams);
132      Left.Split(leftRows, m5CreationParams, globalStdDev);
133      Right = CreateNode(this, m5CreationParams);
134      Right.Split(rightRows, m5CreationParams, globalStdDev);
135    }
136
137    internal bool Prune(IReadOnlyList<int> trainingRows, IReadOnlyList<int> testRows, M5CreationParameters m5CreationParams, CancellationToken cancellation, double globalStdDev) {
138      if (IsLeaf) {
139        BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
140        NumParam = NodeModelParams;
141        return true;
142      }
143      //split training & holdout data
144      IReadOnlyList<int> leftTest, rightTest;
145      SplitRows(testRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTest, out rightTest);
146      IReadOnlyList<int> leftTraining, rightTraining;
147      SplitRows(trainingRows, m5CreationParams.Data, SplitAttr, SplitValue, out leftTraining, out rightTraining);
148
149      //prune children frist
150      var lpruned = Left.Prune(leftTraining, leftTest, m5CreationParams, cancellation, globalStdDev);
151      var rpruned = Right.Prune(rightTraining, rightTest, m5CreationParams, cancellation, globalStdDev);
152      NumParam = Left.NumParam + Right.NumParam + 1;
153
154      //TODO check if this reduces quality. It reduces training effort (consideraby for some pruningTypes)
155      if (!lpruned && !rpruned) return false;
156
157      BuildModel(trainingRows, m5CreationParams.Data, m5CreationParams.Random, m5CreationParams.PruningLeaf, cancellation);
158
159      //check if children will be pruned
160      if (!((PruningBase) m5CreationParams.Pruningtype).Prune(this, m5CreationParams, testRows, globalStdDev)) return false;
161
162      //convert to leafNode
163      ((IntValue) m5CreationParams.Results[M5RuleModel.NoCurrentLeafesResultName].Value).Value -= EnumerateNodes().Count(x => x.IsLeaf) - 1;
164      IsLeaf = true;
165      Right = null;
166      Left = null;
167      NumParam = NodeModelParams;
168      return true;
169    }
170
171    internal void InstallModels(IReadOnlyList<int> rows, IRandom random, IDataset data, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
172      if (!IsLeaf) {
173        IReadOnlyList<int> leftRows, rightRows;
174        SplitRows(rows, data, SplitAttr, SplitValue, out leftRows, out rightRows);
175        Left.InstallModels(leftRows, random, data, leafType, cancellation);
176        Right.InstallModels(rightRows, random, data, leafType, cancellation);
177        return;
178      }
179      BuildModel(rows, data, random, leafType, cancellation);
180    }
181
182    internal IEnumerable<M5NodeModel> EnumerateNodes() {
183      var queue = new Queue<M5NodeModel>();
184      queue.Enqueue(this);
185      while (queue.Count != 0) {
186        var cur = queue.Dequeue();
187        yield return cur;
188        if (cur.Left == null && cur.Right == null) continue;
189        if (cur.Left != null) queue.Enqueue(cur.Left);
190        if (cur.Right != null) queue.Enqueue(cur.Right);
191      }
192    }
193
194    internal void ToRuleNode() {
195      Parent = null;
196    }
197
198    #region Helpers
199    private double GetEstimatedValue(IDataset dataset, int row) {
200      if (!IsLeaf) return (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right).GetEstimatedValue(dataset, row);
201      if (NodeModel == null) throw new NotSupportedException("M5P has not been built correctly");
202      return NodeModel.GetEstimatedValues(dataset, new[] {row}).First();
203    }
204
205    private void BuildModel(IReadOnlyList<int> rows, IDataset data, IRandom random, ILeafType<IRegressionModel> leafType, CancellationToken cancellation) {
206      var reducedData = ReduceDataset(data, rows);
207      var pd = new RegressionProblemData(reducedData, VariablesUsedForPrediction, TargetVariable);
208      pd.TrainingPartition.Start = 0;
209      pd.TrainingPartition.End = pd.TestPartition.Start = pd.TestPartition.End = reducedData.Rows;
210
211      int noparams;
212      NodeModel = leafType.BuildModel(pd, random, cancellation, out noparams);
213      NodeModelParams = noparams;
214      cancellation.ThrowIfCancellationRequested();
215    }
216
217    private IDataset ReduceDataset(IDataset data, IReadOnlyList<int> rows) {
218      return new Dataset(VariablesUsedForPrediction.Concat(new[] {TargetVariable}), VariablesUsedForPrediction.Concat(new[] {TargetVariable}).Select(x => data.GetDoubleValues(x, rows).ToList()));
219    }
220
221    private static void SplitRows(IReadOnlyList<int> rows, IDataset data, string splitAttr, double splitValue, out IReadOnlyList<int> leftRows, out IReadOnlyList<int> rightRows) {
222      var assignment = data.GetDoubleValues(splitAttr, rows).Select(x => x <= splitValue).ToArray();
223      leftRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => x.b).Select(x => x.i).ToList();
224      rightRows = rows.Zip(assignment, (i, b) => new {i, b}).Where(x => !x.b).Select(x => x.i).ToList();
225    }
226    #endregion
227
228    [StorableClass]
229    private sealed class ConfidenceM5NodeModel : M5NodeModel, IConfidenceRegressionModel {
230      #region HLConstructors
231      [StorableConstructor]
232      private ConfidenceM5NodeModel(bool deserializing) : base(deserializing) { }
233      private ConfidenceM5NodeModel(ConfidenceM5NodeModel original, Cloner cloner) : base(original, cloner) { }
234      public ConfidenceM5NodeModel(string targetAttr) : base(targetAttr) { }
235      public ConfidenceM5NodeModel(M5NodeModel parent) : base(parent) { }
236      public override IDeepCloneable Clone(Cloner cloner) {
237        return new ConfidenceM5NodeModel(this, cloner);
238      }
239      #endregion
240
241      public IEnumerable<double> GetEstimatedVariances(IDataset dataset, IEnumerable<int> rows) {
242        return IsLeaf ? ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, rows) : rows.Select(row => GetEstimatedVariance(dataset, row));
243      }
244
245      private double GetEstimatedVariance(IDataset dataset, int row) {
246        if (!IsLeaf)
247          return ((IConfidenceRegressionModel) (dataset.GetDoubleValue(SplitAttr, row) <= SplitValue ? Left : Right)).GetEstimatedVariances(dataset, row.ToEnumerable()).Single();
248        return ((IConfidenceRegressionModel) NodeModel).GetEstimatedVariances(dataset, new[] {row}).First();
249      }
250
251      public override IRegressionSolution CreateRegressionSolution(IRegressionProblemData problemData) {
252        return new ConfidenceRegressionSolution(this, problemData);
253      }
254    }
255  }
256}
Note: See TracBrowser for help on using the repository browser.